use crate::saga_action_generic::Action;
use crate::SagaType;
use anyhow::anyhow;
use petgraph::dot;
use petgraph::graph::NodeIndex;
use petgraph::Directed;
use petgraph::Graph;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::fmt;
use std::sync::Arc;
use thiserror::Error;
use uuid::Uuid;
/// Identifier for a saga, wrapping a [`Uuid`].
///
/// `#[serde(transparent)]` makes this serialize and deserialize exactly like
/// the inner `Uuid`.
#[derive(
    Clone,
    Copy,
    Deserialize,
    Eq,
    JsonSchema,
    Ord,
    PartialEq,
    PartialOrd,
    Serialize,
)]
#[serde(transparent)]
pub struct SagaId(pub Uuid);
// `Debug`, `Display`, and conversion impls are generated by newtype helper
// macros defined outside this file; presumably they forward to the inner
// `Uuid` — confirm against the macro definitions.
NewtypeDebug! { () pub struct SagaId(Uuid); }
NewtypeDisplay! { () pub struct SagaId(Uuid); }
NewtypeFrom! { () pub struct SagaId(Uuid); }
/// Name of an action; the key under which actions are registered in and
/// looked up from an [`ActionRegistry`].
#[derive(
    Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema,
)]
pub struct ActionName(String);
impl ActionName {
    /// Builds an `ActionName` holding an owned copy of any string-like value.
    pub fn new<S: AsRef<str>>(name: S) -> ActionName {
        let text: &str = name.as_ref();
        ActionName(text.to_owned())
    }
}
impl fmt::Debug for ActionName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_fmt(format_args!("{:?}", self.0))
}
}
impl<S> From<S> for ActionName
where
S: AsRef<str>,
{
fn from(s: S) -> Self {
ActionName::new(s)
}
}
/// Name of a node within a saga's Dag.
///
/// `DagBuilder` rejects a name that is appended more than once to the same
/// builder (`DuplicateName`).
#[derive(
    Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema,
)]
pub struct NodeName(String);
impl NodeName {
    /// Builds a `NodeName` holding an owned copy of any string-like value.
    pub fn new<S: AsRef<str>>(name: S) -> NodeName {
        let text: &str = name.as_ref();
        NodeName(text.to_owned())
    }
}
impl AsRef<str> for NodeName {
    /// Borrows the node name as a plain string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl fmt::Debug for NodeName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_fmt(format_args!("{:?}", self.0))
}
}
/// Name of a saga (or of a subsaga embedded in one).
///
/// `#[serde(transparent)]` makes this serialize exactly like the inner string.
#[derive(
    Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, JsonSchema,
)]
#[serde(transparent)]
pub struct SagaName(String);
// `Display` is generated by a newtype helper macro defined outside this file;
// presumably it prints the inner string unquoted (the tests compare
// `saga_name.to_string()` against "test-saga").
NewtypeDisplay! { () pub struct SagaName(String); }
impl SagaName {
    /// Builds a `SagaName` holding an owned copy of `name`.
    pub fn new(name: &str) -> SagaName {
        SagaName(String::from(name))
    }
}
impl fmt::Debug for SagaName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_fmt(format_args!("{:?}", self.0))
}
}
/// Error returned by [`ActionRegistry::get`].
#[derive(Debug)]
pub enum ActionRegistryError {
    /// No action is registered under the requested name.
    NotFound,
}
/// Maps each [`ActionName`] to the action registered under it.
#[derive(Debug)]
pub struct ActionRegistry<UserType: SagaType> {
    actions: BTreeMap<ActionName, Arc<dyn Action<UserType>>>,
}
impl<UserType: SagaType> ActionRegistry<UserType> {
    /// Creates a registry with no actions.
    pub fn new() -> ActionRegistry<UserType> {
        ActionRegistry { actions: BTreeMap::new() }
    }

    /// Registers `action` under the name returned by `action.name()`.
    ///
    /// # Panics
    ///
    /// Panics if an action with the same name was already registered. The
    /// panic message names the offending action.
    pub fn register(&mut self, action: Arc<dyn Action<UserType>>) {
        let name = action.name();
        let previous = self.actions.insert(name.clone(), action);
        assert!(
            previous.is_none(),
            "action {:?} was registered more than once",
            name
        );
    }

    /// Looks up the action registered under `name`.
    ///
    /// # Errors
    ///
    /// Returns [`ActionRegistryError::NotFound`] if no action is registered
    /// under `name`.
    pub fn get(
        &self,
        name: &ActionName,
    ) -> Result<Arc<dyn Action<UserType>>, ActionRegistryError> {
        self.actions.get(name).cloned().ok_or(ActionRegistryError::NotFound)
    }
}
impl<UserType: SagaType> Default for ActionRegistry<UserType> {
fn default() -> Self {
Self::new()
}
}
/// A user-facing description of one node to add to a saga via
/// `DagBuilder::append` or `DagBuilder::append_parallel`.
#[derive(Debug, Clone)]
pub struct Node {
    /// Name of this node; must not repeat a name already appended to the
    /// same `DagBuilder`.
    node_name: NodeName,
    /// The kind of node and its kind-specific data.
    kind: NodeKind,
}
/// Kind-specific contents of a [`Node`].
#[derive(Debug, Clone)]
enum NodeKind {
    /// Refers to the action registered as `action_name`; `label` becomes the
    /// node's human-readable label.
    Action { label: String, action_name: ActionName },
    /// Holds a fixed JSON value.
    Constant { value: serde_json::Value },
    /// Embeds `dag` as a subsaga whose parameters come from the node named
    /// `params_node_name`.
    Subsaga { params_node_name: NodeName, dag: Dag },
}
impl Node {
    /// Returns an action node named `node_name`.
    ///
    /// The node refers to `action` by the name `action.name()` reports, and
    /// `label` is used as its human-readable label.
    pub fn action<N: AsRef<str>, L: AsRef<str>, A: SagaType>(
        node_name: N,
        label: L,
        action: &dyn Action<A>,
    ) -> Node {
        let node_name = NodeName::new(node_name);
        let kind = NodeKind::Action {
            label: label.as_ref().to_owned(),
            action_name: action.name(),
        };
        Node { node_name, kind }
    }

    /// Returns a constant node named `node_name` holding `value`.
    pub fn constant<N: AsRef<str>>(
        node_name: N,
        value: serde_json::Value,
    ) -> Node {
        let node_name = NodeName::new(node_name);
        Node { node_name, kind: NodeKind::Constant { value } }
    }

    /// Returns a node named `node_name` that embeds `dag` as a subsaga.
    ///
    /// The subsaga's parameters come from the node named `params_node_name`.
    /// `DagBuilder` records a `BadSubsagaParams` error unless that node was
    /// appended in an earlier stage.
    pub fn subsaga<N1: AsRef<str>, N2: AsRef<str>>(
        node_name: N1,
        dag: Dag,
        params_node_name: N2,
    ) -> Node {
        let node_name = NodeName::new(node_name);
        let kind = NodeKind::Subsaga {
            params_node_name: NodeName::new(params_node_name),
            dag,
        };
        Node { node_name, kind }
    }
}
/// A node as stored in the saga graph.
///
/// Unlike [`Node`], this also covers two kinds of synthetic node:
/// - the `Start`/`End` nodes added by `SagaDag::new`;
/// - the `SubsagaStart`/`SubsagaEnd` markers that `DagBuilder` places around
///   an inlined subsaga.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub(crate) enum InternalNode {
    /// Entry node of a complete saga; holds the saga's parameters.
    Start { params: Arc<serde_json::Value> },
    /// Exit node of a complete saga.
    End,
    /// A user action node (see `NodeKind::Action`).
    Action { name: NodeName, label: String, action_name: ActionName },
    /// A user constant node (see `NodeKind::Constant`).
    Constant { name: NodeName, value: Arc<serde_json::Value> },
    /// Placed before the first nodes of an inlined subsaga named `saga_name`.
    /// `params_node_name` names the node whose output supplies the subsaga's
    /// parameters.
    SubsagaStart { saga_name: SagaName, params_node_name: NodeName },
    /// Placed after the last nodes of an inlined subsaga; carries the name
    /// the user gave the subsaga node.
    SubsagaEnd { name: NodeName },
}
impl InternalNode {
    /// Returns the user-assigned name of this node, if it has one.
    ///
    /// Action, constant, and subsaga-end nodes carry a [`NodeName`]. Start,
    /// end, and subsaga-start nodes do not, so they return `None`.
    pub fn node_name(&self) -> Option<&NodeName> {
        match self {
            InternalNode::Start { .. }
            | InternalNode::End
            | InternalNode::SubsagaStart { .. } => None,
            // `name` is already bound by reference (`&NodeName`), so it is
            // returned as-is rather than borrowed a second time.
            InternalNode::Action { name, .. }
            | InternalNode::Constant { name, .. }
            | InternalNode::SubsagaEnd { name } => Some(name),
        }
    }

    /// Returns a human-readable label for this node.
    ///
    /// - Action nodes use their user-supplied label.
    /// - Constant nodes show their value as JSON; if serialization fails, the
    ///   error text appears in its place.
    /// - All other nodes get a fixed parenthesized description.
    pub fn label(&self) -> String {
        match self {
            InternalNode::Start { .. } => String::from("(start node)"),
            InternalNode::End => String::from("(end node)"),
            InternalNode::Action { label, .. } => label.clone(),
            InternalNode::Constant { value, .. } => {
                let value_as_json = serde_json::to_string(value)
                    .unwrap_or_else(|e| {
                        format!("(failed to serialize constant value: {:#})", e)
                    });
                format!("(constant = {})", value_as_json)
            }
            InternalNode::SubsagaStart { saga_name, .. } => {
                format!("(subsaga start: {:?})", saga_name)
            }
            InternalNode::SubsagaEnd { .. } => String::from("(subsaga end)"),
        }
    }
}
/// A user-visible node yielded by `SagaDag::get_nodes`.
pub struct NodeEntry<'a> {
    /// The underlying graph node. `SagaDagIterator` only creates entries for
    /// action, constant, and subsaga-end nodes.
    internal: &'a InternalNode,
    /// Index of this node in the saga graph.
    index: NodeIndex,
}
impl<'a> NodeEntry<'a> {
    /// Returns this node's user-assigned name.
    ///
    /// # Panics
    ///
    /// Panics if the underlying node is unnamed. That cannot happen for
    /// entries produced by `SagaDagIterator`, which only yields named kinds.
    pub fn name(&self) -> &NodeName {
        self.internal
            .node_name()
            .expect("NodeEntry is only created for named (user-visible) nodes")
    }

    /// Returns this node's human-readable label (see `InternalNode::label`).
    pub fn label(&self) -> String {
        self.internal.label()
    }

    /// Returns this node's index in the saga graph.
    pub fn index(&self) -> NodeIndex {
        self.index
    }
}
/// A complete saga graph: a built [`Dag`] plus the synthetic start and end
/// nodes added by `SagaDag::new`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SagaDag {
    /// Name of this saga.
    pub(crate) saga_name: SagaName,
    /// All nodes of the saga, including inlined subsaga nodes.
    pub(crate) graph: Graph<InternalNode, ()>,
    /// Index of the `InternalNode::Start` node.
    pub(crate) start_node: NodeIndex,
    /// Index of the `InternalNode::End` node.
    pub(crate) end_node: NodeIndex,
}
/// Iterator over the user-visible nodes of a [`SagaDag`], in node-index
/// order; created by `SagaDag::get_nodes`.
pub struct SagaDagIterator<'a> {
    /// The saga being iterated.
    dag: &'a SagaDag,
    /// Index of the next graph node to examine.
    index: NodeIndex,
}
impl<'a> Iterator for SagaDagIterator<'a> {
    type Item = NodeEntry<'a>;

    /// Advances to the next action, constant, or subsaga-end node.
    ///
    /// Start, end, and subsaga-start nodes are skipped. Returns `None` once
    /// the index passes the last node in the graph.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let index = self.index;
            let internal = self.dag.get(index)?;
            self.index = NodeIndex::new(index.index() + 1);

            let is_user_visible = matches!(
                internal,
                InternalNode::Action { .. }
                    | InternalNode::Constant { .. }
                    | InternalNode::SubsagaEnd { .. }
            );
            if is_user_visible {
                return Some(NodeEntry { internal, index });
            }
        }
    }
}
impl SagaDag {
    /// Turns a built [`Dag`] into a complete saga graph.
    ///
    /// Adds two nodes:
    /// - an `InternalNode::Start` node holding `params`, with an edge to each
    ///   of the Dag's first nodes;
    /// - an `InternalNode::End` node, with an edge from each of the Dag's
    ///   last nodes.
    pub fn new(dagfrag: Dag, params: serde_json::Value) -> SagaDag {
        let mut graph = dagfrag.graph;
        let start_node =
            graph.add_node(InternalNode::Start { params: Arc::new(params) });
        let end_node = graph.add_node(InternalNode::End);
        for &first_node in &dagfrag.first_nodes {
            graph.add_edge(start_node, first_node, ());
        }
        for &last_node in &dagfrag.last_nodes {
            graph.add_edge(last_node, end_node, ());
        }
        SagaDag { saga_name: dagfrag.saga_name, graph, start_node, end_node }
    }

    /// Returns the name of this saga.
    pub fn saga_name(&self) -> &SagaName {
        &self.saga_name
    }

    /// Returns the graph node at `node_index`, or `None` if there is none.
    pub(crate) fn get(&self, node_index: NodeIndex) -> Option<&InternalNode> {
        self.graph.node_weight(node_index)
    }

    /// Returns the index of the first node, in index order, named `name`.
    ///
    /// Unnamed nodes (start, end, subsaga start) never match.
    ///
    /// # Errors
    ///
    /// Returns an error if no node has that name.
    pub fn get_index(&self, name: &str) -> Result<NodeIndex, anyhow::Error> {
        self.graph
            .node_indices()
            .find(|i| {
                self.graph[*i]
                    .node_name()
                    .map(|n| n.as_ref() == name)
                    .unwrap_or(false)
            })
            .ok_or_else(|| anyhow!("saga has no node named \"{}\"", name))
    }

    /// Returns an iterator over the user-visible nodes (actions, constants,
    /// and subsaga ends), in node-index order.
    pub fn get_nodes(&self) -> SagaDagIterator<'_> {
        SagaDagIterator { dag: self, index: NodeIndex::new(0) }
    }

    /// Returns a value whose `Display` output renders the saga graph with
    /// `petgraph::dot` (Graphviz DOT format), without edge labels.
    pub fn dot(&self) -> DagDot<'_> {
        DagDot(&self.graph)
    }
}
/// Display adapter that renders a saga graph with `petgraph::dot`; returned
/// by `SagaDag::dot`.
pub struct DagDot<'a>(&'a Graph<InternalNode, (), Directed, u32>);

impl fmt::Display for DagDot<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Edges carry `()` weights, so edge labels are suppressed.
        let options: &[dot::Config] = &[dot::Config::EdgeNoLabel];
        let rendered = dot::Dot::with_config(&self.0, options);
        write!(f, "{:?}", rendered)
    }
}
/// A built saga graph fragment produced by `DagBuilder::build`.
///
/// It can be turned into a runnable [`SagaDag`] with `SagaDag::new`, or
/// embedded in another saga with `Node::subsaga`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dag {
    /// Name of the saga this Dag describes.
    saga_name: SagaName,
    /// The user nodes (and any inlined subsaga nodes) and their dependencies.
    graph: Graph<InternalNode, ()>,
    /// Nodes of the first appended stage.
    first_nodes: Vec<NodeIndex>,
    /// Nodes of the last appended stage.
    last_nodes: Vec<NodeIndex>,
}
/// Incrementally builds a [`Dag`] out of stages of [`Node`]s.
///
/// Each `append`/`append_parallel` call adds a stage that depends on the
/// previous stage. Errors are not returned from those calls: the first one is
/// recorded, later appends are ignored, and `build()` reports it.
#[derive(Debug)]
pub struct DagBuilder {
    /// Name of the saga being built.
    saga_name: SagaName,
    /// Graph of the nodes appended so far.
    graph: Graph<InternalNode, ()>,
    /// Nodes of the first stage; `None` until a stage has been appended.
    first_added: Option<Vec<NodeIndex>>,
    /// Nodes of the most recently appended stage.
    last_added: Vec<NodeIndex>,
    /// Names of all nodes appended so far; used to reject duplicates and to
    /// validate subsaga parameter sources.
    node_names: BTreeSet<NodeName>,
    /// First error encountered while appending, if any.
    error: Option<DagBuilderErrorKind>,
}
/// Error returned by `DagBuilder::build`.
///
/// It names the saga being built and carries the underlying
/// `DagBuilderErrorKind`.
#[derive(Clone, Debug, Eq, Error, PartialEq)]
#[error("building saga \"{saga_name}\": {kind:#}")]
pub struct DagBuilderError {
    /// Name of the saga whose construction failed.
    saga_name: SagaName,
    /// What went wrong; also exposed as this error's `source()`.
    #[source]
    kind: DagBuilderErrorKind,
}
/// The specific reason a `DagBuilder` failed to produce a `Dag`.
#[derive(Clone, Debug, Eq, Error, PartialEq)]
enum DagBuilderErrorKind {
    /// `build()` was called when the last stage did not contain exactly one
    /// node, including when nothing was appended at all.
    #[error("saga must end with exactly one node")]
    BadOutputNode,
    /// A subsaga node (first field) takes its parameters from a node (second
    /// field) that was not appended in an earlier stage.
    #[error(
        "subsaga node {0:?} has parameters that come from node {1:?}, but it \
         does not depend on any such node"
    )]
    BadSubsagaParams(NodeName, NodeName),
    /// The same node name was appended more than once to one builder.
    #[error("name was used multiple times in the same Dag: {0:?}")]
    DuplicateName(NodeName),
    /// `append_parallel` was called with no nodes.
    #[error("attempted to append 0 nodes in parallel")]
    EmptyStage,
}
impl Dag {
    /// Returns a new, empty builder for a saga named `saga_name`; equivalent
    /// to `DagBuilder::new(saga_name)`.
    pub fn builder(saga_name: SagaName) -> DagBuilder {
        DagBuilder::new(saga_name)
    }
}
impl DagBuilder {
pub fn new(saga_name: SagaName) -> DagBuilder {
DagBuilder {
saga_name,
graph: Graph::new(),
first_added: None,
last_added: vec![],
node_names: BTreeSet::new(),
error: None,
}
}
pub fn append(&mut self, user_node: Node) {
self.append_parallel(vec![user_node])
}
pub fn append_parallel(&mut self, user_nodes: Vec<Node>) {
if self.error.is_some() {
return;
}
if user_nodes.len() == 0 {
self.error = Some(DagBuilderErrorKind::EmptyStage);
return;
}
for node in &user_nodes {
if let NodeKind::Subsaga { params_node_name, .. } = &node.kind {
if !self.node_names.contains(¶ms_node_name)
&& self.error.is_none()
{
self.error = Some(DagBuilderErrorKind::BadSubsagaParams(
node.node_name.clone(),
params_node_name.clone(),
));
return;
}
}
}
for node in &user_nodes {
if !self.node_names.insert(node.node_name.clone()) {
self.error = Some(DagBuilderErrorKind::DuplicateName(
node.node_name.clone(),
));
return;
}
}
let newnodes: Vec<NodeIndex> = user_nodes
.into_iter()
.map(|user_node| self.add_node(user_node))
.collect();
if self.first_added.is_none() {
self.first_added = Some(newnodes.clone());
}
self.set_last(&newnodes);
}
fn add_node(&mut self, user_node: Node) -> NodeIndex {
match user_node.kind {
NodeKind::Action { label, action_name } => {
self.add_simple(InternalNode::Action {
name: user_node.node_name,
label,
action_name,
})
}
NodeKind::Constant { value } => {
self.add_simple(InternalNode::Constant {
name: user_node.node_name,
value: Arc::new(value),
})
}
NodeKind::Subsaga { params_node_name, dag } => {
self.add_subsaga(user_node.node_name, dag, params_node_name)
}
}
}
fn add_simple(&mut self, node: InternalNode) -> NodeIndex {
assert!(matches!(
node,
InternalNode::Constant { .. } | InternalNode::Action { .. }
));
let newnode = self.graph.add_node(node);
self.depends_on_last(newnode);
newnode
}
fn add_subsaga(
&mut self,
name: NodeName,
subsaga_dag: Dag,
params_node_name: NodeName,
) -> NodeIndex {
let node_start = InternalNode::SubsagaStart {
saga_name: subsaga_dag.saga_name.clone(),
params_node_name: NodeName::new(params_node_name),
};
let subsaga_start = self.graph.add_node(node_start);
self.depends_on_last(subsaga_start);
let subgraph = &subsaga_dag.graph;
let mut subsaga_idx_to_saga_idx = BTreeMap::new();
for child_node_index in 0..subgraph.node_count() {
let child_node_index = NodeIndex::from(child_node_index as u32);
let node = subgraph.node_weight(child_node_index).unwrap().clone();
match node {
InternalNode::Start { .. } | InternalNode::End => {
panic!("subsaga Dag contained unexpected node: {:?}", node);
}
InternalNode::Action { .. }
| InternalNode::Constant { .. }
| InternalNode::SubsagaStart { .. }
| InternalNode::SubsagaEnd { .. } => (),
};
let parent_node_index = self.graph.add_node(node);
assert!(subsaga_idx_to_saga_idx
.insert(child_node_index, parent_node_index)
.is_none());
for ancestor_child_node_index in subgraph
.neighbors_directed(child_node_index, petgraph::Incoming)
{
let ancestor_parent_node_index = subsaga_idx_to_saga_idx
.get(&ancestor_child_node_index)
.expect("graph was not a DAG");
self.graph.add_edge(
*ancestor_parent_node_index,
parent_node_index,
(),
);
}
}
for child_first_node in &subsaga_dag.first_nodes {
let parent_first_node =
subsaga_idx_to_saga_idx.get(&child_first_node).unwrap();
self.graph.add_edge(subsaga_start, *parent_first_node, ());
}
let subsaga_end =
self.graph.add_node(InternalNode::SubsagaEnd { name });
for child_last_node in &subsaga_dag.last_nodes {
let parent_last_node =
subsaga_idx_to_saga_idx.get(&child_last_node).unwrap();
self.graph.add_edge(*parent_last_node, subsaga_end, ());
}
subsaga_end
}
fn depends_on_last(&mut self, newnode: NodeIndex) {
for node in &self.last_added {
self.graph.add_edge(*node, newnode, ());
}
}
fn set_last(&mut self, nodes: &[NodeIndex]) {
self.last_added = nodes.to_vec();
}
pub fn build(self) -> Result<Dag, DagBuilderError> {
if let Some(error) = self.error {
return Err(DagBuilderError {
saga_name: self.saga_name.clone(),
kind: error,
});
}
if self.last_added.len() != 1 {
return Err(DagBuilderError {
saga_name: self.saga_name.clone(),
kind: DagBuilderErrorKind::BadOutputNode,
});
}
Ok(Dag {
saga_name: self.saga_name,
graph: self.graph,
first_nodes: self.first_added.unwrap_or_else(|| Vec::new()),
last_nodes: self.last_added,
})
}
}
/// Unit tests for `DagBuilder` validation and `SagaDag` node iteration.
#[cfg(test)]
mod test {
    use super::DagBuilder;
    use super::DagBuilderErrorKind;
    use super::Node;
    use super::NodeName;
    use super::SagaName;

    // A saga with one constant node yields exactly that node from
    // `get_nodes()`, with its name and its "(constant = ...)" label. The
    // synthetic start/end nodes are not yielded.
    #[test]
    fn test_saga_names_and_label() {
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::constant("a", serde_json::Value::Null));
        let dag = crate::SagaDag::new(
            builder.build().expect("Should have built DAG"),
            serde_json::Value::Null,
        );
        let mut nodes = dag.get_nodes();
        let node = nodes.next().unwrap();
        assert_eq!("a", node.name().as_ref());
        assert_eq!("(constant = null)", node.label());
        assert!(nodes.next().is_none());
    }

    // `build()` fails with `BadOutputNode` when the last stage does not hold
    // exactly one node: first with nothing appended, then with a final
    // parallel stage of two nodes.
    #[test]
    fn test_builder_bad_output_nodes() {
        let builder = DagBuilder::new(SagaName::new("test-saga"));
        let result = builder.build();
        println!("{:?}", result);
        match result {
            Ok(_) => panic!("unexpected success"),
            Err(error) => {
                assert_eq!(error.saga_name.to_string(), "test-saga");
                assert!(matches!(
                    error.kind,
                    DagBuilderErrorKind::BadOutputNode
                ));
                assert_eq!(
                    error.to_string(),
                    "building saga \"test-saga\": saga must end with exactly \
                     one node"
                );
            }
        };

        // A final stage with two parallel nodes is also rejected.
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append_parallel(vec![
            Node::constant("a", serde_json::Value::Null),
            Node::constant("b", serde_json::Value::Null),
        ]);
        let result = builder.build();
        println!("{:?}", result);
        assert!(matches!(
            result.unwrap_err().kind,
            DagBuilderErrorKind::BadOutputNode
        ));
    }

    // Appending an empty parallel stage records `EmptyStage`.
    #[test]
    fn test_builder_empty_stage() {
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append_parallel(vec![]);
        let result = builder.build();
        println!("{:?}", result);
        let error = result.unwrap_err();
        assert!(matches!(error.kind, DagBuilderErrorKind::EmptyStage));
        assert_eq!(
            error.to_string(),
            "building saga \"test-saga\": attempted to append 0 nodes in \
             parallel"
        );
    }

    // Duplicate node names are rejected across stages and within one
    // parallel stage. Names inside a subsaga live in a separate namespace,
    // so they may repeat names from the parent saga.
    #[test]
    fn test_builder_duplicate_names() {
        // Same name in two consecutive stages.
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::constant("a", serde_json::Value::Null));
        builder.append(Node::constant("a", serde_json::Value::Null));
        let error = builder.build().unwrap_err();
        println!("{:?}", error);
        assert_eq!(
            error.kind,
            DagBuilderErrorKind::DuplicateName(NodeName::new("a"))
        );
        assert_eq!(
            error.to_string(),
            "building saga \"test-saga\": name was used multiple times in the \
             same Dag: \"a\""
        );

        // Same name in non-adjacent stages.
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::constant("a", serde_json::Value::Null));
        builder.append_parallel(vec![
            Node::constant("b", serde_json::Value::Null),
            Node::constant("c", serde_json::Value::Null),
        ]);
        builder.append(Node::constant("a", serde_json::Value::Null));
        let error = builder.build().unwrap_err();
        println!("{:?}", error);
        assert_eq!(
            error.kind,
            DagBuilderErrorKind::DuplicateName(NodeName::new("a"))
        );

        // Same name twice within one parallel stage.
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::constant("a", serde_json::Value::Null));
        builder.append_parallel(vec![
            Node::constant("b", serde_json::Value::Null),
            Node::constant("b", serde_json::Value::Null),
        ]);
        let error = builder.build().unwrap_err();
        println!("{:?}", error);
        assert_eq!(
            error.kind,
            DagBuilderErrorKind::DuplicateName(NodeName::new("b"))
        );

        // A subsaga may reuse a name ("a") that the parent saga also uses.
        let mut inner_builder = DagBuilder::new(SagaName::new("inner-saga"));
        inner_builder.append(Node::constant("a", serde_json::Value::Null));
        let inner_dag = inner_builder.build().unwrap();
        let mut outer_builder = DagBuilder::new(SagaName::new("outer-saga"));
        outer_builder.append(Node::constant("a", serde_json::Value::Null));
        outer_builder.append(Node::subsaga("b", inner_dag, "a"));
        let _ = outer_builder.build().unwrap();
    }

    // A subsaga's parameters node must come from an earlier stage. The
    // builder rejects a nonexistent node, the subsaga node itself, and a
    // sibling in the same parallel stage.
    #[test]
    fn test_builder_bad_subsaga_params() {
        let mut subsaga_builder = DagBuilder::new(SagaName::new("inner-saga"));
        subsaga_builder.append(Node::constant("a", serde_json::Value::Null));
        let subsaga_dag = subsaga_builder.build().unwrap();

        // Parameters node does not exist.
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::constant("a", serde_json::Value::Null));
        builder.append(Node::subsaga("b", subsaga_dag.clone(), "barf"));
        let error = builder.build().unwrap_err();
        println!("{:?}", error);
        assert_eq!(
            error.kind,
            DagBuilderErrorKind::BadSubsagaParams(
                NodeName::new("b"),
                NodeName::new("barf")
            )
        );
        assert_eq!(
            error.to_string(),
            "building saga \"test-saga\": subsaga node \"b\" has parameters \
             that come from node \"barf\", but it does not depend on any such \
             node"
        );

        // Parameters node is the subsaga node itself.
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::constant("a", serde_json::Value::Null));
        builder.append(Node::subsaga("b", subsaga_dag.clone(), "b"));
        let error = builder.build().unwrap_err();
        println!("{:?}", error);
        assert_eq!(
            error.kind,
            DagBuilderErrorKind::BadSubsagaParams(
                NodeName::new("b"),
                NodeName::new("b")
            )
        );

        // Parameters node is a sibling in the same parallel stage.
        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::constant("a", serde_json::Value::Null));
        builder.append_parallel(vec![
            Node::constant("c", serde_json::Value::Null),
            Node::subsaga("b", subsaga_dag, "c"),
        ]);
        let error = builder.build().unwrap_err();
        println!("{:?}", error);
        assert_eq!(
            error.kind,
            DagBuilderErrorKind::BadSubsagaParams(
                NodeName::new("b"),
                NodeName::new("c")
            )
        );
    }
}