#![allow(clippy::large_enum_variant)]
use crate::dag::SagaDag;
use crate::saga_exec::SagaExecManager;
use crate::saga_exec::SagaExecutor;
use crate::store::SagaCachedState;
use crate::store::SagaCreateParams;
use crate::store::SecStore;
use crate::ActionError;
use crate::ActionRegistry;
use crate::SagaExecStatus;
use crate::SagaId;
use crate::SagaLog;
use crate::SagaNodeEvent;
use crate::SagaResult;
use crate::SagaType;
use anyhow::anyhow;
use anyhow::Context;
use futures::future::BoxFuture;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::StreamExt;
use petgraph::graph::NodeIndex;
use schemars::JsonSchema;
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
use std::convert::TryFrom;
use std::fmt;
use std::future::Future;
use std::num::NonZeroU32;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
/// Capacity of the bounded channel carrying [`SecClientMsg`]s from the
/// [`SecClient`] to the SEC task.
const SEC_CLIENT_MAXQ_MESSAGES: usize = 2;
/// Capacity of the bounded channel carrying [`SecExecMsg`]s from saga
/// executors (via [`SecExecClient`]) to the SEC task.
const SEC_EXEC_MAXQ_MESSAGES: usize = 2;
pub fn sec(log: slog::Logger, sec_store: Arc<dyn SecStore>) -> SecClient {
    // Two independent queues: one for commands from the client, one for
    // messages from the executors of the sagas this SEC manages.
    let (client_tx, client_rx) = mpsc::channel(SEC_CLIENT_MAXQ_MESSAGES);
    let (executor_tx, executor_rx) = mpsc::channel(SEC_EXEC_MAXQ_MESSAGES);

    let coordinator = Sec {
        log,
        sagas: BTreeMap::new(),
        sec_store,
        futures: FuturesUnordered::new(),
        cmd_rx: client_rx,
        shutdown: false,
        exec_tx: executor_tx,
        exec_rx: executor_rx,
    };

    // The SEC's event loop runs on its own tokio task (so this must be
    // called from within a tokio runtime); the client keeps the join handle
    // so `SecClient::shutdown` can wait for it.
    let join_handle = tokio::spawn(coordinator.run());

    SecClient { cmd_tx: client_tx, task: Some(join_handle), shutdown: false }
}
/// Client handle for a SEC task started by [`sec`].
///
/// Every operation is forwarded to the SEC task as a [`SecClientMsg`] and
/// answered over a oneshot channel.
#[derive(Debug)]
pub struct SecClient {
    // Sending side of the client -> SEC command queue.
    cmd_tx: mpsc::Sender<SecClientMsg>,
    // Join handle for the SEC task; taken by `shutdown`.
    task: Option<tokio::task::JoinHandle<()>>,
    // Set by `shutdown` so that `Drop` does not send a second `Shutdown`.
    shutdown: bool,
}
impl SecClient {
    /// Registers a brand-new saga with the SEC.
    ///
    /// The SEC persists a record for the saga through its [`SecStore`] and
    /// then builds an executor for it. The saga is left in the "ready"
    /// state: nothing runs until [`SecClient::saga_start`] is called.
    ///
    /// On success, returns a future that resolves to the saga's result once
    /// it has finished.
    ///
    /// # Errors
    ///
    /// Fails if the store rejects the record, if the executor cannot be
    /// built, or if `saga_id` is already known to this SEC.
    pub async fn saga_create<UserType>(
        &self,
        saga_id: SagaId,
        uctx: Arc<UserType::ExecContextType>,
        dag: Arc<SagaDag>,
        registry: Arc<ActionRegistry<UserType>>,
    ) -> Result<BoxFuture<'static, SagaResult>, anyhow::Error>
    where
        UserType: SagaType + fmt::Debug,
    {
        let params: Box<dyn TemplateParams> =
            Box::new(TemplateParamsForCreate {
                dag: Arc::clone(&dag),
                registry,
                uctx,
            });
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = SecClientMsg::SagaCreate {
            ack_tx: reply_tx,
            saga_id,
            template_params: params,
            dag,
        };
        self.sec_cmd(reply_rx, request).await
    }

    /// Re-registers a saga whose DAG and log events were saved earlier.
    ///
    /// `dag` is the serialized [`SagaDag`]; `log_events` are replayed into
    /// a recovered [`SagaLog`]. Unlike [`SecClient::saga_create`], no new
    /// store record is written. The saga is left in the "ready" state.
    ///
    /// # Errors
    ///
    /// Fails if the log cannot be recovered, if `dag` does not deserialize,
    /// if the executor cannot be built, or if `saga_id` is already in use.
    pub async fn saga_resume<UserType>(
        &self,
        saga_id: SagaId,
        uctx: Arc<UserType::ExecContextType>,
        dag: serde_json::Value,
        registry: Arc<ActionRegistry<UserType>>,
        log_events: Vec<SagaNodeEvent>,
    ) -> Result<BoxFuture<'static, SagaResult>, anyhow::Error>
    where
        UserType: SagaType + fmt::Debug,
    {
        // The log is rebuilt before the DAG is parsed, so a bad log is
        // reported in preference to a bad DAG.
        let saga_log = SagaLog::new_recover(saga_id, log_events)
            .context("recovering log")?;
        let parsed_dag: SagaDag = serde_json::from_value(dag)
            .map_err(ActionError::new_deserialize)?;
        let dag = Arc::new(parsed_dag);
        let params: Box<dyn TemplateParams> =
            Box::new(TemplateParamsForRecover {
                dag: Arc::clone(&dag),
                registry,
                uctx,
                saga_log,
            });
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = SecClientMsg::SagaResume {
            ack_tx: reply_tx,
            saga_id,
            template_params: params,
            dag,
        };
        self.sec_cmd(reply_rx, request).await
    }

    /// Starts running a saga that is in the "ready" state.
    ///
    /// # Errors
    ///
    /// Fails if the SEC does not know `saga_id` or the saga is not "ready".
    pub async fn saga_start(
        &self,
        saga_id: SagaId,
    ) -> Result<(), anyhow::Error> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = SecClientMsg::SagaStart { ack_tx: reply_tx, saga_id };
        self.sec_cmd(reply_rx, request).await
    }

    /// Lists up to `limit` sagas in ascending id order, starting after
    /// `marker` (or from the beginning when `marker` is `None`).
    pub async fn saga_list(
        &self,
        marker: Option<SagaId>,
        limit: NonZeroU32,
    ) -> Vec<SagaView> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request =
            SecClientMsg::SagaList { ack_tx: reply_tx, marker, limit };
        self.sec_cmd(reply_rx, request).await
    }

    /// Fetches a view of one saga; `Err(())` means the SEC does not know it.
    pub async fn saga_get(&self, saga_id: SagaId) -> Result<SagaView, ()> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = SecClientMsg::SagaGet { ack_tx: reply_tx, saga_id };
        self.sec_cmd(reply_rx, request).await
    }

    /// Arranges for the action at `node_id` of the given saga to fail.
    ///
    /// # Errors
    ///
    /// Fails if the saga is unknown or has already finished.
    pub async fn saga_inject_error(
        &self,
        saga_id: SagaId,
        node_id: NodeIndex,
    ) -> Result<(), anyhow::Error> {
        self.inject_fault(saga_id, node_id, ErrorInjected::Fail).await
    }

    /// Arranges for the action and/or undo action at `node_id` of the given
    /// saga to be repeated as described by `repeat`.
    ///
    /// # Errors
    ///
    /// Fails if the saga is unknown or has already finished.
    pub async fn saga_inject_repeat(
        &self,
        saga_id: SagaId,
        node_id: NodeIndex,
        repeat: RepeatInjected,
    ) -> Result<(), anyhow::Error> {
        self.inject_fault(saga_id, node_id, ErrorInjected::Repeat(repeat))
            .await
    }

    /// Shared implementation of the two fault-injection entry points.
    async fn inject_fault(
        &self,
        saga_id: SagaId,
        node_id: NodeIndex,
        error_type: ErrorInjected,
    ) -> Result<(), anyhow::Error> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let request = SecClientMsg::SagaInjectError {
            ack_tx: reply_tx,
            saga_id,
            node_id,
            error_type,
        };
        self.sec_cmd(reply_rx, request).await
    }

    /// Asks the SEC to shut down and waits for its task to exit.
    ///
    /// # Panics
    ///
    /// Panics if the SEC task cannot be reached or joined.
    pub async fn shutdown(mut self) {
        // Mark first so `Drop` does not try to send a second `Shutdown`.
        self.shutdown = true;
        if let Err(error) = self.cmd_tx.send(SecClientMsg::Shutdown).await {
            panic!("failed to send message to SEC: {:#}", error);
        }
        let join_handle = self.task.take().expect("missing task");
        join_handle.await.expect("failed to join on SEC task");
    }

    /// Sends `msg` to the SEC task and waits for its reply on `ack_rx`.
    ///
    /// # Panics
    ///
    /// Panics if the SEC task has gone away before sending or replying.
    async fn sec_cmd<R>(
        &self,
        ack_rx: oneshot::Receiver<R>,
        msg: SecClientMsg,
    ) -> R {
        if let Err(error) = self.cmd_tx.send(msg).await {
            panic!("failed to send message to SEC: {:#}", error);
        }
        ack_rx.await.expect("failed to read SEC response")
    }
}
impl Drop for SecClient {
    /// Best-effort request for the SEC to stop if the client is dropped
    /// without calling [`SecClient::shutdown`].
    ///
    /// `drop` cannot await, so this uses a non-blocking `try_send`. A full
    /// or closed queue is ignored.
    fn drop(&mut self) {
        if self.shutdown {
            // `shutdown()` already delivered the message.
            return;
        }
        self.cmd_tx.try_send(SecClientMsg::Shutdown).ok();
    }
}
// Point-in-time snapshot of a saga, as returned by `saga_get` / `saga_list`.
// (Plain `//` comments are used here on purpose: `///` doc comments would be
// picked up by the `JsonSchema` derive and change the generated schema.)
#[derive(Clone, Debug, JsonSchema, Serialize)]
pub struct SagaView {
    pub id: SagaId,
    // Excluded from serialized output (see `#[serde(skip)]`).
    #[serde(skip)]
    pub state: SagaStateView,
    // The saga's DAG in serialized (JSON value) form.
    dag: serde_json::Value,
}
impl SagaView {
    /// Produces a future that builds a [`SagaView`] for `saga`.
    ///
    /// The id, DAG, and executor handle are copied out up front, so the
    /// returned future does not borrow `saga`.
    fn from_saga(saga: &Saga) -> impl Future<Output = Self> {
        let state_future = SagaStateView::from_run_state(&saga.run_state);
        let (id, dag) = (saga.id, saga.dag.clone());
        async move {
            let state = state_future.await;
            SagaView { id, state, dag }
        }
    }

    /// Packages the saga id, DAG, and the log events recorded so far into a
    /// [`SagaSerialized`].
    pub fn serialized(&self) -> SagaSerialized {
        let dag = self.dag.clone();
        let events = self.state.status().log().events().to_vec();
        SagaSerialized { saga_id: self.id, dag, events }
    }
}
/// The externally visible state of a saga, mirroring [`SagaRunState`] but
/// holding the executor's status instead of the executor itself.
#[derive(Debug, Clone)]
pub enum SagaStateView {
    /// Created or resumed, but not yet started.
    Ready {
        status: SagaExecStatus,
    },
    /// Started and still executing.
    Running {
        status: SagaExecStatus,
    },
    /// Finished; `result` is what was delivered to the saga's waiter.
    Done {
        status: SagaExecStatus,
        result: SagaResult,
    },
}
impl SagaStateView {
    /// Produces a future that builds a [`SagaStateView`] for `run_state`.
    ///
    /// Everything needed from `run_state` is copied out synchronously: an
    /// `Arc` of the executor, or the final status and result. The returned
    /// future therefore does not borrow `run_state`. For ready and running
    /// sagas, the future asks the executor for its current status.
    fn from_run_state(
        run_state: &SagaRunState,
    ) -> impl Future<Output = SagaStateView> {
        // Owned copy of the parts of `SagaRunState` needed after the borrow
        // of `run_state` ends.
        enum Snapshot {
            Ready(Arc<dyn SagaExecManager>),
            Running(Arc<dyn SagaExecManager>),
            Finished(SagaExecStatus, SagaResult),
        }

        let snapshot = match run_state {
            SagaRunState::Ready { exec, .. } => Snapshot::Ready(exec.clone()),
            SagaRunState::Running { exec, .. } => {
                Snapshot::Running(exec.clone())
            }
            SagaRunState::Done { status, result } => {
                Snapshot::Finished(status.clone(), result.clone())
            }
        };

        async move {
            match snapshot {
                Snapshot::Ready(executor) => {
                    let status = executor.status().await;
                    SagaStateView::Ready { status }
                }
                Snapshot::Running(executor) => {
                    let status = executor.status().await;
                    SagaStateView::Running { status }
                }
                Snapshot::Finished(status, result) => {
                    SagaStateView::Done { status, result }
                }
            }
        }
    }

    /// Returns the execution status, whatever state the saga is in.
    pub fn status(&self) -> &SagaExecStatus {
        match self {
            SagaStateView::Ready { status }
            | SagaStateView::Running { status }
            | SagaStateView::Done { status, .. } => status,
        }
    }
}
/// Fault-injection request: repeat a node's action and/or undo action.
///
/// Passed to [`SecClient::saga_inject_repeat`]. Judging by the tests in this
/// file, `action` and `undo` are the total number of times each is executed
/// (e.g. `action: 2` yields two calls of the action); confirm against
/// `SagaExecutor::inject_repeat`.
#[derive(Debug, Copy, Clone)]
pub struct RepeatInjected {
    pub action: NonZeroU32,
    pub undo: NonZeroU32,
}
/// Kind of fault carried by [`SecClientMsg::SagaInjectError`].
#[derive(Debug)]
enum ErrorInjected {
    /// Make the node's action fail (routes to `SagaExecManager::inject_error`).
    Fail,
    /// Repeat the node's action/undo (routes to `inject_repeat`).
    Repeat(RepeatInjected),
}
/// Commands sent from a [`SecClient`] to the SEC task.
///
/// Every variant except `Shutdown` carries a oneshot `ack_tx` on which the
/// SEC delivers its reply.
enum SecClientMsg {
    /// Create a new saga. The SEC writes a store record, then inserts the
    /// saga in the "ready" state.
    SagaCreate {
        ack_tx: oneshot::Sender<
            Result<BoxFuture<'static, SagaResult>, anyhow::Error>,
        >,
        saga_id: SagaId,
        template_params: Box<dyn TemplateParams>,
        dag: Arc<SagaDag>,
    },
    /// Re-insert a previously persisted saga in the "ready" state, without
    /// writing a new store record.
    SagaResume {
        ack_tx: oneshot::Sender<
            Result<BoxFuture<'static, SagaResult>, anyhow::Error>,
        >,
        saga_id: SagaId,
        template_params: Box<dyn TemplateParams>,
        dag: Arc<SagaDag>,
    },
    /// Move a "ready" saga to running.
    SagaStart {
        ack_tx: oneshot::Sender<Result<(), anyhow::Error>>,
        saga_id: SagaId,
    },
    /// List up to `limit` sagas with ids strictly after `marker`.
    SagaList {
        ack_tx: oneshot::Sender<Vec<SagaView>>,
        marker: Option<SagaId>,
        limit: NonZeroU32,
    },
    /// Fetch one saga's view; `Err(())` means it is unknown.
    SagaGet {
        ack_tx: oneshot::Sender<Result<SagaView, ()>>,
        saga_id: SagaId,
    },
    /// Inject a fault into a node of a saga that has not finished.
    SagaInjectError {
        ack_tx: oneshot::Sender<Result<(), anyhow::Error>>,
        saga_id: SagaId,
        node_id: NodeIndex,
        error_type: ErrorInjected,
    },
    /// Begin shutting down the SEC (no reply is sent).
    Shutdown,
}
/// Hand-written because the oneshot `ack_tx` fields (and boxed futures in
/// their payload types) are not useful to print; every other field is shown.
impl fmt::Debug for SecClientMsg {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SecClientMsg::")?;
        match self {
            SecClientMsg::SagaCreate {
                saga_id, template_params, dag, ..
            } => f
                .debug_struct("SagaCreate")
                .field("saga_id", saga_id)
                .field("template_params", template_params)
                .field("dag", dag)
                .finish(),
            SecClientMsg::SagaResume {
                saga_id, template_params, dag, ..
            } => f
                .debug_struct("SagaResume")
                .field("saga_id", saga_id)
                .field("template_params", template_params)
                .field("dag", dag)
                .finish(),
            SecClientMsg::SagaStart { saga_id, .. } => {
                f.debug_struct("SagaStart").field("saga_id", saga_id).finish()
            }
            // Previously printed as a bare name, hiding the paging arguments.
            SecClientMsg::SagaList { marker, limit, .. } => f
                .debug_struct("SagaList")
                .field("marker", marker)
                .field("limit", limit)
                .finish(),
            SecClientMsg::SagaGet { saga_id, .. } => {
                f.debug_struct("SagaGet").field("saga_id", saga_id).finish()
            }
            SecClientMsg::SagaInjectError {
                saga_id,
                node_id,
                error_type,
                ..
            } => f
                .debug_struct("SagaInjectError")
                .field("saga_id", saga_id)
                // Label fixed: was "node_Id", inconsistent with the field.
                .field("node_id", node_id)
                .field("error_type", error_type)
                .finish(),
            // `Shutdown` is a unit variant; match it as one.
            SecClientMsg::Shutdown => f.write_str("Shutdown"),
        }
    }
}
/// Type-erased bundle of everything needed to build a saga executor.
///
/// Erasing the `UserType` generic lets [`SecClientMsg`] carry it; the SEC
/// calls `into_exec` when it inserts the saga.
trait TemplateParams: Send + fmt::Debug {
    /// Consumes the parameters and builds the executor for `saga_id`.
    /// `sec_hdl` is the executor's channel back to the SEC.
    fn into_exec(
        self: Box<Self>,
        log: slog::Logger,
        saga_id: SagaId,
        sec_hdl: SecExecClient,
    ) -> Result<Arc<dyn SagaExecManager>, anyhow::Error>;
}
/// [`TemplateParams`] for a saga created via [`SecClient::saga_create`].
#[derive(Debug)]
struct TemplateParamsForCreate<UserType: SagaType + fmt::Debug> {
    dag: Arc<SagaDag>,
    registry: Arc<ActionRegistry<UserType>>,
    uctx: Arc<UserType::ExecContextType>,
}
impl<UserType> TemplateParams for TemplateParamsForCreate<UserType>
where
UserType: SagaType + fmt::Debug,
{
fn into_exec(
self: Box<Self>,
log: slog::Logger,
saga_id: SagaId,
sec_hdl: SecExecClient,
) -> Result<Arc<dyn SagaExecManager>, anyhow::Error> {
Ok(Arc::new(SagaExecutor::new(
log,
saga_id,
self.dag,
self.registry,
self.uctx,
sec_hdl,
)?))
}
}
/// [`TemplateParams`] for a saga resumed via [`SecClient::saga_resume`];
/// carries the recovered log in addition to the creation inputs.
#[derive(Debug)]
struct TemplateParamsForRecover<UserType: SagaType + fmt::Debug> {
    dag: Arc<SagaDag>,
    registry: Arc<ActionRegistry<UserType>>,
    uctx: Arc<UserType::ExecContextType>,
    saga_log: SagaLog,
}
impl<UserType> TemplateParams for TemplateParamsForRecover<UserType>
where
UserType: SagaType + fmt::Debug,
{
fn into_exec(
self: Box<Self>,
log: slog::Logger,
saga_id: SagaId,
sec_hdl: SecExecClient,
) -> Result<Arc<dyn SagaExecManager>, anyhow::Error> {
Ok(Arc::new(SagaExecutor::new_recover(
log,
saga_id,
self.dag,
self.registry,
self.uctx,
sec_hdl,
self.saga_log,
)?))
}
}
/// Handle given to a saga executor for talking back to the SEC.
///
/// Each handle is bound to one saga (`saga_id`) and sends [`SecExecMsg`]s on
/// the SEC's executor queue.
#[derive(Debug)]
pub struct SecExecClient {
    saga_id: SagaId,
    exec_tx: mpsc::Sender<SecExecMsg>,
}
impl SecExecClient {
    /// Hands `event` to the SEC and waits until the SEC has passed it to
    /// `SecStore::record_event` and acknowledged it.
    ///
    /// # Panics
    ///
    /// Panics if `event` belongs to a different saga than this handle, or if
    /// the SEC task has gone away.
    pub async fn record(&self, event: SagaNodeEvent) {
        assert_eq!(event.saga_id, self.saga_id);
        let (ack_tx, ack_rx) = oneshot::channel();
        self.sec_send(
            ack_rx,
            SecExecMsg::LogEvent(SagaLogEventData { event, ack_tx }),
        )
        .await
    }

    /// Asks the SEC to update this saga's cached state in the store, and
    /// waits for the acknowledgement.
    ///
    /// # Panics
    ///
    /// Panics if the SEC task has gone away.
    pub async fn saga_update(&self, update: SagaCachedState) {
        let (ack_tx, ack_rx) = oneshot::channel();
        self.sec_send(
            ack_rx,
            SecExecMsg::UpdateCachedState(SagaUpdateCacheData {
                ack_tx,
                saga_id: self.saga_id,
                updated_state: update,
            }),
        )
        .await
    }

    /// Fetches a view of any saga known to the SEC (not only this one);
    /// `Err(())` means the saga is unknown.
    ///
    /// # Panics
    ///
    /// Panics if the SEC task has gone away.
    pub async fn saga_get(&self, saga_id: SagaId) -> Result<SagaView, ()> {
        let (ack_tx, ack_rx) = oneshot::channel();
        self.sec_send(
            ack_rx,
            SecExecMsg::SagaGet(SagaGetData { ack_tx, saga_id }),
        )
        .await
    }

    /// Sends `msg` to the SEC and waits for the reply on `ack_rx`.
    ///
    /// Uses the same panic messages as `SecClient::sec_cmd`, in place of the
    /// previous bare `unwrap()`s, so a dead SEC is easy to diagnose.
    async fn sec_send<T>(
        &self,
        ack_rx: oneshot::Receiver<T>,
        msg: SecExecMsg,
    ) -> T {
        self.exec_tx.send(msg).await.unwrap_or_else(|error| {
            panic!("failed to send message to SEC: {:#}", error)
        });
        ack_rx.await.expect("failed to read SEC response")
    }
}
/// Messages from a saga executor (via [`SecExecClient`]) to the SEC.
#[derive(Debug)]
enum SecExecMsg {
    /// Look up a saga's view.
    SagaGet(SagaGetData),
    /// Record a node event in the store.
    LogEvent(SagaLogEventData),
    /// Update the saga's cached state in the store.
    UpdateCachedState(SagaUpdateCacheData),
}
/// Payload of [`SecExecMsg::SagaGet`].
#[derive(Debug)]
struct SagaGetData {
    // Receives the view, or `Err(())` if the saga is unknown.
    ack_tx: oneshot::Sender<Result<SagaView, ()>>,
    saga_id: SagaId,
}
/// Payload of [`SecExecMsg::LogEvent`].
#[derive(Debug)]
struct SagaLogEventData {
    // Signalled after the store call for `event` returns.
    ack_tx: oneshot::Sender<()>,
    event: SagaNodeEvent,
}
/// Payload of [`SecExecMsg::UpdateCachedState`].
#[derive(Debug)]
struct SagaUpdateCacheData {
    // Signalled after the store update returns.
    ack_tx: oneshot::Sender<()>,
    saga_id: SagaId,
    updated_state: SagaCachedState,
}
/// State owned by the SEC task (see `Sec::run`).
struct Sec {
    log: slog::Logger,
    // Every saga this SEC knows about, in any run state.
    sagas: BTreeMap<SagaId, Saga>,
    sec_store: Arc<dyn SecStore>,
    // In-flight background work. A future may yield a `SecStep` that the
    // event loop dispatches when it completes.
    futures: FuturesUnordered<BoxFuture<'static, Option<SecStep>>>,
    // Commands from the `SecClient`.
    cmd_rx: mpsc::Receiver<SecClientMsg>,
    // Cloned into each `SecExecClient`; holding it here also keeps
    // `exec_rx` from ever closing.
    exec_tx: mpsc::Sender<SecExecMsg>,
    // Messages from saga executors.
    exec_rx: mpsc::Receiver<SecExecMsg>,
    // Set once shutdown has been requested.
    shutdown: bool,
}
impl Sec {
    /// Runs the SEC event loop until shutdown has been requested and every
    /// future in `self.futures` has completed.
    ///
    /// Each iteration handles whichever of the following is ready first:
    ///
    /// * completion of a future in `self.futures`, which may yield a
    ///   [`SecStep`] to dispatch;
    /// * a [`SecClientMsg`] from the client (only polled before shutdown);
    /// * a [`SecExecMsg`] from a saga executor.
    async fn run(mut self) {
        info!(&self.log, "SEC running");
        while !self.shutdown || !self.futures.is_empty() {
            tokio::select! {
                maybe_work_done = self.futures.next(),
                    if !self.futures.is_empty() => {
                    // The precondition guarantees the set was non-empty, so
                    // `next()` cannot have yielded `None`.
                    let work_result = maybe_work_done
                        .expect("FuturesUnordered unexpectedly empty");
                    if let Some(next_step) = work_result {
                        self.dispatch_work(next_step);
                    }
                },
                // Stop polling the client channel once shutdown has begun.
                // If the client has since dropped its sender, `recv()` would
                // resolve to `None` immediately on every iteration, and this
                // loop would busy-spin until the remaining futures finished.
                maybe_client_message = self.cmd_rx.recv(),
                    if !self.shutdown => {
                    match maybe_client_message {
                        Some(client_message) => {
                            self.dispatch_client_message(client_message);
                        }
                        None => {
                            // The client went away without delivering
                            // `Shutdown` (e.g. `SecClient::drop`'s
                            // `try_send` hit a full queue). No further
                            // commands can arrive, so shut down instead of
                            // panicking.
                            warn!(
                                &self.log,
                                "SEC client channel closed without shutdown"
                            );
                            self.cmd_shutdown();
                        }
                    }
                },
                maybe_exec_message = self.exec_rx.recv() => {
                    // `self.exec_tx` keeps this channel open, so `None` is
                    // not expected; tolerate it anyway.
                    if let Some(exec_message) = maybe_exec_message {
                        self.dispatch_exec_message(exec_message);
                    }
                }
            }
        }
    }

    /// Sends `value` to the waiting requester. If the requester has gone
    /// away, logs a warning instead of failing.
    fn client_respond<T>(
        log: &slog::Logger,
        ack_tx: oneshot::Sender<T>,
        value: T,
    ) {
        if ack_tx.send(value).is_err() {
            warn!(log, "unexpectedly failed to send response to SEC client");
        }
    }

    /// Applies a [`SecStep`] yielded by a completed background future.
    fn dispatch_work(&mut self, step: SecStep) {
        match step {
            SecStep::SagaInsert(insert_data) => self.saga_insert(insert_data),
            SecStep::SagaDone(done_data) => self.saga_finished(done_data),
        }
    }

    /// Builds an executor for a saga and records the saga in the "ready"
    /// state. Starts it immediately if `rec.autostart` is set.
    ///
    /// Replies on `rec.ack_tx` with either an error, or a future that
    /// resolves to the saga's result once it finishes.
    fn saga_insert(&mut self, rec: SagaInsertData) {
        let saga_id = rec.saga_id;
        let serialized_dag = rec.serialized_dag;
        let ack_tx = rec.ack_tx;
        let log = rec.log;

        // Reject duplicates before building an executor that would only be
        // thrown away.
        if self.sagas.contains_key(&saga_id) {
            Sec::client_respond(
                &log,
                ack_tx,
                Err(anyhow!(
                    "saga_id {} cannot be inserted; already in use",
                    saga_id
                )),
            );
            return;
        }

        let exec_tx = self.exec_tx.clone();
        let sec_hdl = SecExecClient { saga_id, exec_tx };
        let exec = match rec.template_params.into_exec(
            log.new(o!()),
            saga_id,
            sec_hdl,
        ) {
            Ok(exec) => exec,
            Err(error) => {
                Sec::client_respond(&log, ack_tx, Err(error));
                return;
            }
        };

        // `done_tx` is fired by `saga_finished`; `done_rx` backs the future
        // handed back to the requester.
        let (done_tx, done_rx) = oneshot::channel();
        let saga = Saga {
            id: saga_id,
            log: log.new(o!()),
            dag: serialized_dag,
            run_state: SagaRunState::Ready {
                exec: Arc::clone(&exec),
                waiter: done_tx,
            },
        };
        assert!(self.sagas.insert(saga_id, saga).is_none());

        if rec.autostart {
            self.do_saga_start(saga_id)
                .expect("failed to start newly inserted saga");
        }

        Sec::client_respond(
            &log,
            ack_tx,
            Ok(async move {
                done_rx.await.unwrap_or_else(|_| {
                    panic!("failed to wait for saga to finish")
                })
            }
            .boxed()),
        );
    }

    /// Handles [`SecClientMsg::SagaStart`].
    fn cmd_saga_start(
        &mut self,
        ack_tx: oneshot::Sender<Result<(), anyhow::Error>>,
        saga_id: SagaId,
    ) {
        let result = self.do_saga_start(saga_id);
        Sec::client_respond(&self.log, ack_tx, result);
    }

    /// Moves a saga from "ready" to "running" and queues its execution.
    ///
    /// # Errors
    ///
    /// Fails if the saga is unknown or not in the "ready" state. In the
    /// latter case the saga is left untouched in `self.sagas`. Previously it
    /// was dropped from the table, which discarded its waiter and made the
    /// later `saga_finished` for a running saga panic.
    fn do_saga_start(&mut self, saga_id: SagaId) -> Result<(), anyhow::Error> {
        let Saga { id, log, dag, run_state } = self.saga_remove(saga_id)?;
        let (exec, waiter) = match run_state {
            SagaRunState::Ready { exec, waiter } => (exec, waiter),
            other_state => {
                // Put the saga back exactly as it was before failing.
                self.sagas.insert(
                    saga_id,
                    Saga { id, log, dag, run_state: other_state },
                );
                return Err(anyhow!(
                    "saga not in \"ready\" state: {:?}",
                    saga_id
                ));
            }
        };
        self.sagas.insert(
            saga_id,
            Saga {
                id,
                log,
                dag,
                run_state: SagaRunState::Running {
                    exec: Arc::clone(&exec),
                    waiter,
                },
            },
        );
        self.futures.push(
            async move {
                exec.run().await;
                Some(SecStep::SagaDone(SagaDoneData {
                    saga_id,
                    result: exec.result(),
                    status: exec.status().await,
                }))
            }
            .boxed(),
        );
        Ok(())
    }

    /// Records that a running saga finished.
    ///
    /// Delivers the result to the saga's waiter and moves the saga to the
    /// "done" state.
    ///
    /// # Panics
    ///
    /// Panics if the saga is unknown or was not running.
    fn saga_finished(&mut self, done_data: SagaDoneData) {
        let saga_id = done_data.saga_id;
        let saga = self
            .sagas
            .remove(&saga_id)
            .expect("finished saga missing from SEC saga table");
        info!(&saga.log, "saga finished");
        let Saga { log, dag, run_state, .. } = saga;
        match run_state {
            SagaRunState::Running { waiter, .. } => {
                Sec::client_respond(&log, waiter, done_data.result.clone());
                self.sagas.insert(
                    saga_id,
                    Saga {
                        id: saga_id,
                        log,
                        run_state: SagaRunState::Done {
                            status: done_data.status,
                            result: done_data.result,
                        },
                        dag,
                    },
                );
            }
            other_state => panic!(
                "saga future completion for unexpected state: {:?}",
                other_state
            ),
        }
    }

    /// Routes a client command to its handler.
    fn dispatch_client_message(&mut self, message: SecClientMsg) {
        match message {
            SecClientMsg::SagaCreate {
                ack_tx,
                saga_id,
                template_params,
                dag,
            } => {
                self.cmd_saga_create(ack_tx, saga_id, template_params, dag);
            }
            SecClientMsg::SagaResume {
                ack_tx,
                saga_id,
                template_params,
                dag,
            } => {
                self.cmd_saga_resume(ack_tx, saga_id, template_params, dag);
            }
            SecClientMsg::SagaStart { ack_tx, saga_id } => {
                self.cmd_saga_start(ack_tx, saga_id);
            }
            SecClientMsg::SagaList { ack_tx, marker, limit } => {
                self.cmd_saga_list(ack_tx, marker, limit);
            }
            SecClientMsg::SagaGet { ack_tx, saga_id } => {
                self.cmd_saga_get(ack_tx, saga_id);
            }
            SecClientMsg::SagaInjectError {
                ack_tx,
                saga_id,
                node_id,
                error_type,
            } => {
                self.cmd_saga_inject_error(
                    ack_tx, saga_id, node_id, error_type,
                );
            }
            SecClientMsg::Shutdown => self.cmd_shutdown(),
        }
    }

    /// Handles [`SecClientMsg::SagaCreate`]. The saga is not autostarted.
    fn cmd_saga_create(
        &mut self,
        ack_tx: oneshot::Sender<
            Result<BoxFuture<'static, SagaResult>, anyhow::Error>,
        >,
        saga_id: SagaId,
        template_params: Box<dyn TemplateParams>,
        dag: Arc<SagaDag>,
    ) {
        self.do_saga_create(ack_tx, saga_id, template_params, dag, false);
    }

    /// Persists a new saga record through the store in a background future.
    ///
    /// On success the future yields [`SecStep::SagaInsert`], so the saga is
    /// inserted from the event loop. On failure the future replies to the
    /// client with the error.
    fn do_saga_create(
        &mut self,
        ack_tx: oneshot::Sender<
            Result<BoxFuture<'static, SagaResult>, anyhow::Error>,
        >,
        saga_id: SagaId,
        template_params: Box<dyn TemplateParams>,
        dag: Arc<SagaDag>,
        autostart: bool,
    ) {
        let log = self.log.new(o!(
            "saga_id" => saga_id.to_string(),
            "saga_name" => dag.saga_name.to_string(),
        ));
        let serialized_dag = serde_json::to_value(&dag)
            .map_err(ActionError::new_serialize)
            .context("serializing new saga dag")
            .unwrap();
        debug!(&log, "saga create";
            "dag" => serde_json::to_string(&serialized_dag).unwrap()
        );
        let saga_create = SagaCreateParams {
            id: saga_id,
            name: dag.saga_name.clone(),
            dag: serialized_dag.clone(),
            state: SagaCachedState::Running,
        };
        let store = Arc::clone(&self.sec_store);
        let create_future = async move {
            let result = store
                .saga_create(saga_create)
                .await
                .context("creating saga record");
            if let Err(error) = result {
                Sec::client_respond(&log, ack_tx, Err(error));
                None
            } else {
                Some(SecStep::SagaInsert(SagaInsertData {
                    ack_tx,
                    log,
                    saga_id,
                    template_params,
                    serialized_dag,
                    autostart,
                }))
            }
        }
        .boxed();
        self.futures.push(create_future);
    }

    /// Handles [`SecClientMsg::SagaResume`]. The saga is inserted right
    /// away, without writing a new store record, and is not autostarted.
    fn cmd_saga_resume(
        &mut self,
        ack_tx: oneshot::Sender<
            Result<BoxFuture<'static, SagaResult>, anyhow::Error>,
        >,
        saga_id: SagaId,
        template_params: Box<dyn TemplateParams>,
        dag: Arc<SagaDag>,
    ) {
        let log = self.log.new(o!(
            "saga_id" => saga_id.to_string(),
            "saga_name" => dag.saga_name.to_string(),
        ));
        let serialized_dag = serde_json::to_value(&dag)
            .map_err(ActionError::new_serialize)
            .context("serializing new saga dag")
            .unwrap();
        info!(&log, "saga resume";
            "dag" => serde_json::to_string(&serialized_dag).unwrap()
        );
        self.saga_insert(SagaInsertData {
            ack_tx,
            log,
            saga_id,
            template_params,
            serialized_dag,
            autostart: false,
        })
    }

    /// Handles [`SecClientMsg::SagaList`].
    ///
    /// Returns up to `limit` sagas in id order, strictly after `marker` when
    /// it is given. The views are assembled in a background future because
    /// reading executor status is async.
    fn cmd_saga_list(
        &self,
        ack_tx: oneshot::Sender<Vec<SagaView>>,
        marker: Option<SagaId>,
        limit: NonZeroU32,
    ) {
        trace!(&self.log, "saga_list");
        let log = self.log.new(o!());
        let limit = usize::try_from(limit.get()).unwrap();
        let view_futures = match marker {
            None => self
                .sagas
                .values()
                .take(limit)
                .map(SagaView::from_saga)
                .collect::<Vec<_>>(),
            Some(marker_value) => {
                use std::ops::Bound;
                self.sagas
                    .range((Bound::Excluded(marker_value), Bound::Unbounded))
                    .take(limit)
                    .map(|(_, v)| SagaView::from_saga(v))
                    .collect::<Vec<_>>()
            }
        };
        self.futures.push(
            async move {
                let views = futures::stream::iter(view_futures)
                    .then(|f| f)
                    .collect::<Vec<SagaView>>()
                    .await;
                Sec::client_respond(&log, ack_tx, views);
                None
            }
            .boxed(),
        );
    }

    /// Handles [`SecClientMsg::SagaGet`] (also used for executor lookups).
    /// Replies `Err(())` right away if the saga is unknown.
    fn cmd_saga_get(
        &self,
        ack_tx: oneshot::Sender<Result<SagaView, ()>>,
        saga_id: SagaId,
    ) {
        trace!(&self.log, "saga_get"; "saga_id" => %saga_id);
        let saga = match self.saga_lookup(saga_id) {
            Ok(saga) => saga,
            Err(_) => {
                Sec::client_respond(&self.log, ack_tx, Err(()));
                return;
            }
        };
        let fut = SagaView::from_saga(saga);
        let log = self.log.new(o!());
        let the_fut = async move {
            let saga_view = fut.await;
            Sec::client_respond(&log, ack_tx, Ok(saga_view));
            None
        };
        self.futures.push(the_fut.boxed());
    }

    /// Handles [`SecClientMsg::SagaInjectError`] for a ready or running
    /// saga. The injection itself runs in a background future.
    fn cmd_saga_inject_error(
        &self,
        ack_tx: oneshot::Sender<Result<(), anyhow::Error>>,
        saga_id: SagaId,
        node_id: NodeIndex,
        error_type: ErrorInjected,
    ) {
        trace!(
            &self.log,
            "saga_inject_error";
            "saga_id" => %saga_id,
            "node_id" => ?node_id,
        );
        let saga = match self.saga_lookup(saga_id) {
            Ok(saga) => saga,
            Err(e) => {
                Sec::client_respond(&self.log, ack_tx, Err(e));
                return;
            }
        };
        let exec = match &saga.run_state {
            SagaRunState::Ready { exec, .. } => Arc::clone(exec),
            SagaRunState::Running { exec, .. } => Arc::clone(exec),
            SagaRunState::Done { .. } => {
                Sec::client_respond(
                    &self.log,
                    ack_tx,
                    Err(anyhow!("saga is not running: {}", saga_id)),
                );
                return;
            }
        };
        let log = self.log.new(o!());
        let fut = async move {
            match error_type {
                ErrorInjected::Fail => {
                    exec.inject_error(node_id).await;
                }
                ErrorInjected::Repeat(repeat) => {
                    exec.inject_repeat(node_id, repeat).await;
                }
            }
            Sec::client_respond(&log, ack_tx, Ok(()));
            None
        }
        .boxed();
        self.futures.push(fut);
    }

    /// Handles [`SecClientMsg::Shutdown`]. `run` exits once the outstanding
    /// futures drain.
    fn cmd_shutdown(&mut self) {
        info!(&self.log, "initiating shutdown");
        self.shutdown = true;
    }

    /// Routes a message from a saga executor. Store writes run as
    /// background futures; lookups reuse `cmd_saga_get`.
    fn dispatch_exec_message(&mut self, exec_message: SecExecMsg) {
        let log = self.log.new(o!());
        let store = Arc::clone(&self.sec_store);
        match exec_message {
            SecExecMsg::LogEvent(log_data) => {
                self.futures
                    .push(Sec::executor_log(log, store, log_data).boxed());
            }
            SecExecMsg::UpdateCachedState(update_data) => {
                self.futures.push(
                    Sec::executor_update(log, store, update_data).boxed(),
                );
            }
            SecExecMsg::SagaGet(get_data) => {
                self.executor_saga_get(get_data);
            }
        };
    }

    /// Writes a node event to the store, then acknowledges the executor.
    async fn executor_log(
        log: slog::Logger,
        store: Arc<dyn SecStore>,
        log_data: SagaLogEventData,
    ) -> Option<SecStep> {
        debug!(&log, "saga log event";
            "new_state" => ?log_data.event
        );
        let ack_tx = log_data.ack_tx;
        store.record_event(log_data.event).await;
        Sec::client_respond(&log, ack_tx, ());
        None
    }

    /// Writes a cached-state update to the store, then acknowledges the
    /// executor.
    async fn executor_update(
        log: slog::Logger,
        store: Arc<dyn SecStore>,
        update_data: SagaUpdateCacheData,
    ) -> Option<SecStep> {
        info!(&log, "update for saga cached state";
            "saga_id" => update_data.saga_id.to_string(),
            "new_state" => ?update_data.updated_state
        );
        let ack_tx = update_data.ack_tx;
        store.saga_update(update_data.saga_id, update_data.updated_state).await;
        Sec::client_respond(&log, ack_tx, ());
        None
    }

    /// Handles an executor's saga lookup exactly like a client lookup.
    fn executor_saga_get(&self, get_data: SagaGetData) {
        self.cmd_saga_get(get_data.ack_tx, get_data.saga_id);
    }

    /// Borrows a saga by id, or returns a "no such saga" error.
    fn saga_lookup(&self, saga_id: SagaId) -> Result<&Saga, anyhow::Error> {
        self.sagas
            .get(&saga_id)
            .ok_or_else(|| anyhow!("no such saga: {:?}", saga_id))
    }

    /// Removes a saga by id, or returns a "no such saga" error.
    fn saga_remove(&mut self, saga_id: SagaId) -> Result<Saga, anyhow::Error> {
        self.sagas
            .remove(&saga_id)
            .ok_or_else(|| anyhow!("no such saga: {:?}", saga_id))
    }
}
/// The SEC's record for one saga.
struct Saga {
    id: SagaId,
    // Logger carrying this saga's key-value context.
    log: slog::Logger,
    // The saga's DAG in serialized (JSON value) form.
    dag: serde_json::Value,
    run_state: SagaRunState,
}
/// Lifecycle state of a saga inside the SEC.
#[derive(Debug)]
pub enum SagaRunState {
    /// Executor built but not started. `waiter` later receives the result.
    Ready {
        exec: Arc<dyn SagaExecManager>,
        waiter: oneshot::Sender<SagaResult>,
    },
    /// The executor's `run()` future has been queued on the SEC.
    Running {
        exec: Arc<dyn SagaExecManager>,
        waiter: oneshot::Sender<SagaResult>,
    },
    /// Execution finished; the result has been sent to the waiter.
    Done {
        status: SagaExecStatus,
        result: SagaResult,
    },
}
/// Follow-up work yielded by a background future. It needs `&mut Sec` and
/// is therefore applied from the event loop (`Sec::dispatch_work`).
enum SecStep {
    /// Insert a saga whose store record was created successfully.
    SagaInsert(SagaInsertData),
    /// A saga's executor finished running.
    SagaDone(SagaDoneData),
}
/// Everything `Sec::saga_insert` needs to register a saga.
struct SagaInsertData {
    log: slog::Logger,
    saga_id: SagaId,
    template_params: Box<dyn TemplateParams>,
    serialized_dag: serde_json::Value,
    // Where the create/resume reply goes.
    ack_tx:
        oneshot::Sender<Result<BoxFuture<'static, SagaResult>, anyhow::Error>>,
    // Whether to start the saga immediately after insertion.
    autostart: bool,
}
/// Outcome of a saga run, collected from its executor after `run()` returns.
struct SagaDoneData {
    saga_id: SagaId,
    result: SagaResult,
    status: SagaExecStatus,
}
/// Serializable form of a saga: its id, its DAG, and the node events logged
/// so far (see `SagaView::serialized`).
#[derive(Deserialize, Serialize)]
pub struct SagaSerialized {
    pub saga_id: SagaId,
    pub dag: serde_json::Value,
    pub events: Vec<SagaNodeEvent>,
}
impl TryFrom<SagaSerialized> for SagaLog {
type Error = anyhow::Error;
fn try_from(s: SagaSerialized) -> Result<SagaLog, anyhow::Error> {
SagaLog::new_recover(s.saga_id, s.events)
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        ActionContext, ActionError, ActionFunc, DagBuilder, Node, SagaId,
        SagaName,
    };
    use serde::{Deserialize, Serialize};
    use slog::Drain;
    use std::sync::Mutex;
    use uuid::Uuid;

    /// Builds an async terminal logger that only emits warnings and above.
    fn new_log() -> slog::Logger {
        let decorator = slog_term::TermDecorator::new().build();
        let drain = slog_term::FullFormat::new(decorator).build().fuse();
        let drain = slog::LevelFilter(drain, slog::Level::Warning).fuse();
        let drain = slog_async::Async::new(drain).build().fuse();
        slog::Logger::root(drain, slog::o!())
    }

    /// Starts a fresh SEC backed by an in-memory store.
    fn new_sec(log: &slog::Logger) -> SecClient {
        crate::sec(
            log.new(slog::o!()),
            Arc::new(crate::InMemorySecStore::new()),
        )
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct TestParams;

    /// User context that counts how many times each action/undo ran.
    #[derive(Debug)]
    struct TestContext {
        counters: Mutex<BTreeMap<String, u32>>,
    }

    impl TestContext {
        fn new() -> Self {
            TestContext { counters: Mutex::new(BTreeMap::new()) }
        }

        /// Increments the counter for `name`, starting it at 1.
        fn call(&self, name: &str) {
            let mut map = self.counters.lock().unwrap();
            *map.entry(name.to_string()).or_insert(0) += 1;
        }

        /// Returns the counter for `name`, or 0 if it was never called.
        fn get_count(&self, name: &str) -> u32 {
            let map = self.counters.lock().unwrap();
            map.get(name).copied().unwrap_or(0)
        }
    }

    #[derive(Debug)]
    struct TestSaga;
    impl SagaType for TestSaga {
        type ExecContextType = TestContext;
    }

    /// Builds a two-node saga (`n1` then `n2`). Node actions output 1 and 2
    /// respectively, and every action/undo bumps a counter in `TestContext`.
    fn make_test_saga() -> (Arc<ActionRegistry<TestSaga>>, Arc<SagaDag>) {
        async fn do_n1(
            ctx: ActionContext<TestSaga>,
        ) -> Result<i32, ActionError> {
            ctx.user_data().call("do_n1");
            Ok(1)
        }
        async fn undo_n1(
            ctx: ActionContext<TestSaga>,
        ) -> Result<(), anyhow::Error> {
            ctx.user_data().call("undo_n1");
            Ok(())
        }

        async fn do_n2(
            ctx: ActionContext<TestSaga>,
        ) -> Result<i32, ActionError> {
            ctx.user_data().call("do_n2");
            Ok(2)
        }
        async fn undo_n2(
            ctx: ActionContext<TestSaga>,
        ) -> Result<(), anyhow::Error> {
            ctx.user_data().call("undo_n2");
            Ok(())
        }

        let mut registry = ActionRegistry::new();
        let action_n1 = ActionFunc::new_action("n1_out", do_n1, undo_n1);
        registry.register(Arc::clone(&action_n1));
        let action_n2 = ActionFunc::new_action("n2_out", do_n2, undo_n2);
        registry.register(Arc::clone(&action_n2));

        let mut builder = DagBuilder::new(SagaName::new("test-saga"));
        builder.append(Node::action("n1_out", "n1", &*action_n1));
        builder.append(Node::action("n2_out", "n2", &*action_n2));

        (
            Arc::new(registry),
            Arc::new(SagaDag::new(
                builder.build().unwrap(),
                serde_json::to_value(TestParams {}).unwrap(),
            )),
        )
    }

    /// Inputs for `saga_runner_helper`: optional faults to inject, plus the
    /// expected per-node action/undo counts (index 0 = n1, index 1 = n2).
    struct TestArguments<'a> {
        repeat: Option<(NodeIndex, RepeatInjected)>,
        fail_node: Option<NodeIndex>,
        counts: &'a [Counts],
    }

    struct Counts {
        action: u32,
        undo: u32,
    }

    /// Creates the test saga, injects the requested faults, runs it to
    /// completion, and checks the outcome and the action/undo counts.
    async fn saga_runner_helper(arguments: TestArguments<'_>) {
        let log = new_log();
        let sec = new_sec(&log);

        let (registry, dag) = make_test_saga();

        let saga_id = SagaId(Uuid::new_v4());
        let context = Arc::new(TestContext::new());
        let saga_future = sec
            .saga_create(saga_id, Arc::clone(&context), dag, registry)
            .await
            .expect("failed to create saga");

        if let Some((repeat_node, repeat_operation)) = arguments.repeat {
            sec.saga_inject_repeat(saga_id, repeat_node, repeat_operation)
                .await
                .expect("failed to inject repeat");
        }

        if let Some(fail_node) = arguments.fail_node {
            sec.saga_inject_error(saga_id, fail_node)
                .await
                .expect("failed to inject error");
        }

        sec.saga_start(saga_id).await.expect("failed to start saga running");
        let result = saga_future.await;

        if arguments.fail_node.is_some() {
            result.kind.expect_err("should have failed; we injected an error!");
        } else {
            let output = result.kind.unwrap();
            assert_eq!(output.lookup_node_output::<i32>("n1_out").unwrap(), 1);
            assert_eq!(output.lookup_node_output::<i32>("n2_out").unwrap(), 2);
        }

        let counts = arguments.counts;
        assert_eq!(context.get_count("do_n1"), counts[0].action);
        assert_eq!(context.get_count("undo_n1"), counts[0].undo);
        assert_eq!(context.get_count("do_n2"), counts[1].action);
        assert_eq!(context.get_count("undo_n2"), counts[1].undo);
    }

    #[tokio::test]
    async fn test_saga_create_and_start_executes_saga() {
        saga_runner_helper(TestArguments {
            repeat: None,
            fail_node: None,
            counts: &[
                Counts { action: 1, undo: 0 },
                Counts { action: 1, undo: 0 },
            ],
        })
        .await;
    }

    #[tokio::test]
    async fn test_saga_inject_repeat_and_then_succeed() {
        saga_runner_helper(TestArguments {
            repeat: Some((
                NodeIndex::new(0),
                RepeatInjected {
                    action: NonZeroU32::new(2).unwrap(),
                    undo: NonZeroU32::new(1).unwrap(),
                },
            )),
            fail_node: None,
            counts: &[
                Counts { action: 2, undo: 0 },
                Counts { action: 1, undo: 0 },
            ],
        })
        .await;
    }

    #[tokio::test]
    async fn test_saga_inject_repeat_and_then_fail() {
        saga_runner_helper(TestArguments {
            repeat: Some((
                NodeIndex::new(0),
                RepeatInjected {
                    action: NonZeroU32::new(2).unwrap(),
                    undo: NonZeroU32::new(1).unwrap(),
                },
            )),
            fail_node: Some(NodeIndex::new(1)),
            counts: &[
                Counts { action: 2, undo: 1 },
                Counts { action: 0, undo: 0 },
            ],
        })
        .await;
    }

    #[tokio::test]
    async fn test_saga_inject_repeat_fail_and_repeat_undo() {
        saga_runner_helper(TestArguments {
            repeat: Some((
                NodeIndex::new(0),
                RepeatInjected {
                    action: NonZeroU32::new(2).unwrap(),
                    undo: NonZeroU32::new(2).unwrap(),
                },
            )),
            fail_node: Some(NodeIndex::new(1)),
            counts: &[
                Counts { action: 2, undo: 2 },
                Counts { action: 0, undo: 0 },
            ],
        })
        .await;
    }

    #[tokio::test]
    async fn test_saga_inject_and_fail_repeat_undo_only() {
        saga_runner_helper(TestArguments {
            repeat: Some((
                NodeIndex::new(0),
                RepeatInjected {
                    action: NonZeroU32::new(1).unwrap(),
                    undo: NonZeroU32::new(2).unwrap(),
                },
            )),
            fail_node: Some(NodeIndex::new(1)),
            counts: &[
                Counts { action: 1, undo: 2 },
                Counts { action: 0, undo: 0 },
            ],
        })
        .await;
    }

    #[tokio::test]
    async fn test_saga_inject_and_fail_repeat_many_times() {
        saga_runner_helper(TestArguments {
            repeat: Some((
                NodeIndex::new(0),
                RepeatInjected {
                    action: NonZeroU32::new(3).unwrap(),
                    undo: NonZeroU32::new(5).unwrap(),
                },
            )),
            fail_node: Some(NodeIndex::new(1)),
            counts: &[
                Counts { action: 3, undo: 5 },
                Counts { action: 0, undo: 0 },
            ],
        })
        .await;
    }

    #[tokio::test]
    async fn test_saga_fails_after_error_injection() {
        saga_runner_helper(TestArguments {
            repeat: None,
            fail_node: Some(NodeIndex::new(0)),
            counts: &[
                Counts { action: 0, undo: 0 },
                Counts { action: 0, undo: 0 },
            ],
        })
        .await;
    }

    #[tokio::test]
    async fn test_saga_create_without_start_does_not_run_saga() {
        let log = new_log();
        let sec = new_sec(&log);

        let (registry, dag) = make_test_saga();

        let saga_id = SagaId(Uuid::new_v4());
        let context = Arc::new(TestContext::new());
        let saga_future = sec
            .saga_create(saga_id, Arc::clone(&context), dag, registry)
            .await
            .expect("failed to create saga");

        // Give the SEC a brief chance to (wrongly) run the saga.
        tokio::select! {
            _ = saga_future => {
                panic!("saga_create future shouldn't complete without start");
            },
            _ = tokio::time::sleep(tokio::time::Duration::from_millis(1)) => {},
        }

        assert_eq!(context.get_count("do_n1"), 0);
        assert_eq!(context.get_count("undo_n1"), 0);
        assert_eq!(context.get_count("do_n2"), 0);
        assert_eq!(context.get_count("undo_n2"), 0);
    }

    #[tokio::test]
    async fn test_saga_resume_and_start_executes_saga() {
        let log = new_log();
        let sec = new_sec(&log);

        let (registry, dag) = make_test_saga();

        let saga_id = SagaId(Uuid::new_v4());
        let context = Arc::new(TestContext::new());
        let saga_future = sec
            .saga_resume(
                saga_id,
                Arc::clone(&context),
                serde_json::to_value(Arc::try_unwrap(dag).unwrap()).unwrap(),
                registry,
                vec![],
            )
            .await
            .expect("failed to resume saga");

        sec.saga_start(saga_id).await.expect("failed to start saga running");
        let result = saga_future.await;

        // Check both node outputs, as `saga_runner_helper` does.
        let output = result.kind.unwrap();
        assert_eq!(output.lookup_node_output::<i32>("n1_out").unwrap(), 1);
        assert_eq!(output.lookup_node_output::<i32>("n2_out").unwrap(), 2);
        assert_eq!(context.get_count("do_n1"), 1);
        assert_eq!(context.get_count("undo_n1"), 0);
        assert_eq!(context.get_count("do_n2"), 1);
        assert_eq!(context.get_count("undo_n2"), 0);
    }

    #[tokio::test]
    async fn test_saga_resuming_already_created_saga_fails() {
        let log = new_log();
        let sec = new_sec(&log);

        let (registry, dag) = make_test_saga();

        let saga_id = SagaId(Uuid::new_v4());
        let context = Arc::new(TestContext::new());
        let _ = sec
            .saga_create(
                saga_id,
                Arc::clone(&context),
                dag.clone(),
                registry.clone(),
            )
            .await
            .expect("failed to create saga");

        let err = sec
            .saga_resume(
                saga_id,
                Arc::clone(&context),
                serde_json::to_value((*dag).clone()).unwrap(),
                registry,
                vec![],
            )
            .await
            .err()
            .expect("Resuming the saga should fail");

        assert!(err.to_string().contains("cannot be inserted; already in use"));
    }

    #[tokio::test]
    async fn test_saga_start_without_create_fails() {
        let log = new_log();
        let sec = new_sec(&log);
        let saga_id = SagaId(Uuid::new_v4());

        let err = sec
            .saga_start(saga_id)
            .await
            .err()
            .expect("Starting an uncreated saga should fail");

        assert!(err.to_string().contains("no such saga"));
    }

    #[tokio::test]
    async fn test_sagas_can_only_be_started_once() {
        let log = new_log();
        let sec = new_sec(&log);

        let (registry, dag) = make_test_saga();

        let saga_id = SagaId(Uuid::new_v4());
        let context = Arc::new(TestContext::new());
        let _ = sec
            .saga_create(saga_id, Arc::clone(&context), dag, registry)
            .await
            .expect("failed to create saga");

        sec.saga_start(saga_id).await.expect("failed to start saga");
        let err = sec
            .saga_start(saga_id)
            .await
            .err()
            .expect("Double starting a saga should fail");

        assert!(err.to_string().contains("saga not in \"ready\" state"));
    }
}