use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::Arc;
pub use cano_macros::compensatable_task as task;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::error::CanoError;
use crate::resource::Resources;
use crate::task::{TaskConfig, TaskResult};
/// A task whose forward effects can be undone ("compensated") later.
///
/// [`run`](CompensatableTask::run) performs the forward step and returns both the
/// next workflow state and an `Output` value; that same `Output` is what
/// [`compensate`](CompensatableTask::compensate) receives to roll the step back.
///
/// When driven through `CompensatableAdapter` (defined in this module):
/// - `Output` is encoded with `serde_json` after `run` and decoded again before
///   `compensate`, which is why it must be `Serialize + DeserializeOwned`;
/// - the forward `run` is bounded by `config().attempt_timeout` when one is set;
/// - a `TaskResult::Split` result, or an `Output` that fails to serialize, is
///   rejected: `compensate` is invoked immediately with the live `Output` and the
///   run reports an error;
/// - if `run` itself returns `Err`, `compensate` is not called for that step.
///
/// Implementations are annotated with `#[crate::saga::task]` (the
/// `compensatable_task` attribute macro re-exported from `cano_macros`).
/// NOTE(review): presumably the macro desugars the `async fn`s into boxed
/// futures — confirm in `cano_macros`.
#[crate::saga::task]
pub trait CompensatableTask<TState, TResourceKey = Cow<'static, str>>: Send + Sync
where
    TState: Clone + fmt::Debug + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
{
    /// Value produced by the forward run and handed back to `compensate`.
    type Output: Serialize + DeserializeOwned + Send + Sync + 'static;

    /// Execution configuration for this task; defaults to `TaskConfig::default()`.
    fn config(&self) -> TaskConfig {
        TaskConfig::default()
    }

    /// Name used in log fields and error messages; defaults to
    /// `std::any::type_name` of the implementing type.
    fn name(&self) -> Cow<'static, str> {
        Cow::Borrowed(std::any::type_name::<Self>())
    }

    /// Performs the forward step.
    ///
    /// Returns the next state together with the `Output` needed to compensate it.
    ///
    /// # Errors
    /// Any `CanoError`; a failed forward run is not compensated by the adapter.
    async fn run(
        &self,
        res: &Resources<TResourceKey>,
    ) -> Result<(TaskResult<TState>, Self::Output), CanoError>;

    /// Undoes the effects of a forward run, given the `Output` that run returned.
    ///
    /// # Errors
    /// Any `CanoError`. When invoked inline by the adapter, an error (or panic)
    /// here is combined with the rejection error into `CompensationFailed`.
    async fn compensate(
        &self,
        res: &Resources<TResourceKey>,
        output: Self::Output,
    ) -> Result<(), CanoError>;
}
/// Pairs a task identifier with the encoded output of that task's forward run.
///
/// NOTE(review): presumably recorded so the step can later be rolled back via
/// `ErasedCompensatable::compensate`; the producers/consumers are outside this
/// view — confirm. The derived `Debug` prints the full byte blob.
#[derive(Debug, Clone)]
pub(crate) struct CompensationEntry {
    // Identifier of the task that produced `output_blob`.
    // NOTE(review): presumably the task name or registry key — confirm with callers.
    pub task_id: Arc<str>,
    // Encoded forward output; presumably the bytes returned by
    // `ErasedCompensatable::run` (JSON when produced by `CompensatableAdapter`).
    pub output_blob: Vec<u8>,
}
/// Object-safe, type-erased counterpart of [`CompensatableTask`].
///
/// There is no `Output` associated type: the forward output travels as raw
/// bytes, and both operations return boxed `Send` futures, so the trait can be
/// used as a trait object.
pub trait ErasedCompensatable<
    TState: Clone + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
>: Send + Sync
{
    /// Name of the underlying task, for logs and error messages.
    fn name(&self) -> Cow<'static, str>;

    /// Execution configuration of the underlying task.
    fn config(&self) -> TaskConfig;

    /// Runs the forward step.
    ///
    /// Resolves to the next state plus the task's output encoded as bytes.
    fn run<'a>(&'a self, res: &'a Resources<TResourceKey>) -> ForwardRunFuture<'a, TState>;

    /// Rolls back a forward step.
    ///
    /// `output_blob` is expected to be the bytes a prior [`run`](Self::run) produced.
    fn compensate<'a>(
        &'a self,
        res: &'a Resources<TResourceKey>,
        output_blob: &'a [u8],
    ) -> CompensateFuture<'a>;
}
/// Boxed `Send` future returned by [`ErasedCompensatable::run`].
///
/// It resolves to the forward result plus the task's output encoded as bytes.
pub type ForwardRunFuture<'a, TState> =
    Pin<Box<dyn Future<Output = Result<(TaskResult<TState>, Vec<u8>), CanoError>> + Send + 'a>>;
/// Boxed `Send` future returned by [`ErasedCompensatable::compensate`].
pub type CompensateFuture<'a> = Pin<Box<dyn Future<Output = Result<(), CanoError>> + Send + 'a>>;
/// Immediately compensates a forward run whose result the adapter had to reject.
///
/// It is used when `run` succeeded but its result cannot be accepted, for example
/// a split state or an output that fails to serialize. The task's `compensate`
/// is executed in place with the live `output`.
///
/// # Errors
///
/// This function always returns `Err`:
/// - `original_err`, unchanged, when compensation completes cleanly;
/// - `CanoError::compensation_failed([original_err, compensate_err])` when
///   compensation returns an error;
/// - `CanoError::compensation_failed([original_err, task_execution(..)])` when
///   compensation panics (the panic is caught, not propagated).
async fn run_inline_compensate<TState, TResourceKey, T>(
    task: &T,
    res: &Resources<TResourceKey>,
    output: T::Output,
    original_err: CanoError,
) -> Result<(TaskResult<TState>, Vec<u8>), CanoError>
where
    TState: Clone + fmt::Debug + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
    T: CompensatableTask<TState, TResourceKey> + ?Sized + 'static,
{
    use futures_util::FutureExt;
    use std::panic::AssertUnwindSafe;

    let task_name = task.name();

    #[cfg(feature = "tracing")]
    tracing::debug!(
        task = %task_name,
        "running compensate inline (adapter rejected the run result)"
    );

    // Guard against panics so a crashing compensation becomes an error value
    // instead of unwinding through the caller.
    let guarded = AssertUnwindSafe(task.compensate(res, output)).catch_unwind();

    let secondary_err = match guarded.await {
        // Clean rollback: surface only the reason the forward result was rejected.
        Ok(Ok(())) => return Err(original_err),
        Ok(Err(compensate_err)) => {
            #[cfg(feature = "tracing")]
            tracing::error!(
                task = %task_name,
                error = %compensate_err,
                "inline compensate failed"
            );
            compensate_err
        }
        Err(payload) => {
            let panic_msg = crate::workflow::panic_payload_message(&*payload);
            #[cfg(feature = "tracing")]
            tracing::error!(task = %task_name, panic = %panic_msg, "inline compensate panicked");
            CanoError::task_execution(format!(
                "inline compensate for `{task_name}` panicked: {panic_msg}"
            ))
        }
    };

    Err(CanoError::compensation_failed(vec![original_err, secondary_err]))
}
/// Wraps a typed [`CompensatableTask`] (held behind an `Arc`) so it can be driven
/// through the object-safe [`ErasedCompensatable`] interface.
pub(crate) struct CompensatableAdapter<T>(pub Arc<T>);
/// Erases the typed `Output` of a [`CompensatableTask`].
///
/// The output is encoded with `serde_json` after the forward run and decoded
/// again before compensation.
impl<TState, TResourceKey, T> ErasedCompensatable<TState, TResourceKey> for CompensatableAdapter<T>
where
    TState: Clone + fmt::Debug + Send + Sync + 'static,
    TResourceKey: Hash + Eq + Send + Sync + 'static,
    T: CompensatableTask<TState, TResourceKey> + 'static,
{
    fn name(&self) -> Cow<'static, str> {
        T::name(&self.0)
    }

    fn config(&self) -> TaskConfig {
        T::config(&self.0)
    }

    fn run<'a>(&'a self, res: &'a Resources<TResourceKey>) -> ForwardRunFuture<'a, TState> {
        // The per-attempt limit is read eagerly, before the future is first polled.
        let attempt_timeout = self.0.config().attempt_timeout;
        let task: &'a T = &self.0;

        Box::pin(async move {
            let forward = task.run(res);

            // Apply the timeout when one is configured. On expiry the forward future is
            // dropped and there is no output, so nothing is compensated here. A forward
            // `Err` is propagated as-is, also without compensation.
            let (state, output) = if let Some(d) = attempt_timeout {
                tokio::time::timeout(d, forward).await.map_err(|_| {
                    CanoError::timeout(format!(
                        "compensatable task `{}` forward run exceeded attempt_timeout {d:?}",
                        task.name()
                    ))
                })??
            } else {
                forward.await?
            };

            // Split results are rejected; undo the forward work right away before
            // reporting the rejection.
            if matches!(&state, TaskResult::Split(_)) {
                let split_err = CanoError::workflow(format!(
                    "Compensatable task `{}` returned a split result — split states cannot be compensatable",
                    task.name()
                ));
                return run_inline_compensate(task, res, output, split_err).await;
            }

            // `compensate` below decodes these bytes. If encoding fails, there is
            // nothing to decode later, so compensate now with the live value.
            let serialize_err = match serde_json::to_vec(&output) {
                Ok(blob) => return Ok((state, blob)),
                Err(serialize_err) => CanoError::task_execution(format!(
                    "serialize compensation output for `{}`: {serialize_err}",
                    task.name()
                )),
            };
            run_inline_compensate(task, res, output, serialize_err).await
        })
    }

    fn compensate<'a>(
        &'a self,
        res: &'a Resources<TResourceKey>,
        output_blob: &'a [u8],
    ) -> CompensateFuture<'a> {
        Box::pin(async move {
            // Rebuild the typed output from the bytes `run` produced. A corrupt blob is
            // reported as an error and the task's `compensate` is not invoked.
            match serde_json::from_slice::<T::Output>(output_blob) {
                Ok(output) => self.0.compensate(res, output).await,
                Err(e) => Err(CanoError::generic(format!(
                    "deserialize compensation output for `{}`: {e}",
                    self.0.name()
                ))),
            }
        })
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `CompensatableAdapter` and `run_inline_compensate`. They cover
    // delegation, the blob round-trip, forward errors, attempt timeouts, split and
    // serialization rejection, and corrupt blobs.
    use super::*;
    use std::time::Duration;

    // Minimal workflow state used by the test tasks.
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
    enum S {
        A,
        B,
    }

    // Shared record of the `output` values that `compensate` calls received.
    type Log = Arc<std::sync::Mutex<Vec<u32>>>;

    // Creates an empty `Log`.
    fn log() -> Log {
        Arc::new(std::sync::Mutex::new(Vec::new()))
    }

    // Configurable test task; the boolean flags select which path `run` and
    // `compensate` take.
    #[derive(Clone)]
    struct Comp {
        // Value returned as `Output` by `run`.
        output: u32,
        // `run` returns `TaskResult::Split` instead of `Single`.
        forward_split: bool,
        // `run` returns an error.
        forward_err: bool,
        // `compensate` returns an error (after logging).
        comp_fail: bool,
        // `compensate` panics (after logging).
        comp_panic: bool,
        // Receives every `output` passed to `compensate`.
        log: Log,
    }

    impl Comp {
        // A `Comp` whose `run` and `compensate` both succeed.
        fn ok(output: u32, log: &Log) -> Self {
            Self {
                output,
                forward_split: false,
                forward_err: false,
                comp_fail: false,
                comp_panic: false,
                log: Arc::clone(log),
            }
        }
    }

    #[crate::saga::task]
    impl CompensatableTask<S> for Comp {
        type Output = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal()
        }
        async fn run(&self, _res: &Resources) -> Result<(TaskResult<S>, u32), CanoError> {
            if self.forward_err {
                return Err(CanoError::task_execution("forward failed"));
            }
            if self.forward_split {
                return Ok((TaskResult::Split(vec![S::A]), self.output));
            }
            Ok((TaskResult::Single(S::B), self.output))
        }
        async fn compensate(&self, _res: &Resources, output: u32) -> Result<(), CanoError> {
            // Log first, so tests can assert compensate was attempted even when it
            // then fails or panics.
            self.log.lock().unwrap().push(output);
            if self.comp_panic {
                panic!("compensate panicked");
            }
            if self.comp_fail {
                return Err(CanoError::generic("compensate failed"));
            }
            Ok(())
        }
    }

    // Output type whose `Serialize` impl always fails, used to exercise the
    // serialize-failure path. Its `Deserialize` impl ignores the input and succeeds.
    #[derive(Clone)]
    struct FailSer;
    impl serde::Serialize for FailSer {
        fn serialize<Ser>(&self, _s: Ser) -> Result<Ser::Ok, Ser::Error>
        where
            Ser: serde::Serializer,
        {
            Err(<Ser::Error as serde::ser::Error>::custom(
                "intentional serialize failure",
            ))
        }
    }
    impl<'de> serde::Deserialize<'de> for FailSer {
        fn deserialize<D>(_d: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            Ok(FailSer)
        }
    }

    // Task whose `Output` (`FailSer`) cannot be serialized; its `compensate` logs 999.
    #[derive(Clone)]
    struct UnserTask {
        log: Log,
    }
    #[crate::saga::task]
    impl CompensatableTask<S> for UnserTask {
        type Output = FailSer;
        async fn run(&self, _res: &Resources) -> Result<(TaskResult<S>, FailSer), CanoError> {
            Ok((TaskResult::Single(S::B), FailSer))
        }
        async fn compensate(&self, _res: &Resources, _output: FailSer) -> Result<(), CanoError> {
            self.log.lock().unwrap().push(999);
            Ok(())
        }
    }

    // Task whose forward run sleeps 5 s, far beyond its 20 ms `attempt_timeout`.
    #[derive(Clone)]
    struct SlowTask;
    #[crate::saga::task]
    impl CompensatableTask<S> for SlowTask {
        type Output = u32;
        fn config(&self) -> TaskConfig {
            TaskConfig::minimal().with_attempt_timeout(Duration::from_millis(20))
        }
        async fn run(&self, _res: &Resources) -> Result<(TaskResult<S>, u32), CanoError> {
            tokio::time::sleep(Duration::from_secs(5)).await;
            Ok((TaskResult::Single(S::B), 1))
        }
        async fn compensate(&self, _res: &Resources, _output: u32) -> Result<(), CanoError> {
            Ok(())
        }
    }

    // The adapter's `name`/`config` come from the wrapped task (the default
    // `type_name`-based name and each task's own `config`).
    #[test]
    fn name_and_config_delegate_to_the_inner_task() {
        let l = log();
        let adapter = CompensatableAdapter(Arc::new(Comp::ok(1, &l)));
        assert!(adapter.name().contains("Comp"));
        assert!(adapter.config().attempt_timeout.is_none());
        let slow = CompensatableAdapter(Arc::new(SlowTask));
        assert_eq!(
            slow.config().attempt_timeout,
            Some(Duration::from_millis(20))
        );
    }

    // Happy path: the blob from `run` is JSON of the output, and feeding it back to
    // `compensate` delivers the original value to the task.
    #[tokio::test]
    async fn forward_success_serializes_output_and_round_trips_through_compensate() {
        let l = log();
        let adapter = CompensatableAdapter(Arc::new(Comp::ok(42, &l)));
        let res = Resources::new();
        let (state, blob) = adapter.run(&res).await.expect("forward run ok");
        assert!(matches!(state, TaskResult::Single(S::B)));
        assert_eq!(serde_json::from_slice::<u32>(&blob).unwrap(), 42);
        adapter
            .compensate(&res, &blob)
            .await
            .expect("compensate ok");
        assert_eq!(
            *l.lock().unwrap(),
            vec![42],
            "compensate received the round-tripped output"
        );
    }

    // A failing forward run is propagated and never triggers compensation.
    #[tokio::test]
    async fn forward_error_propagates_without_running_compensate() {
        let l = log();
        let mut task = Comp::ok(1, &l);
        task.forward_err = true;
        let adapter = CompensatableAdapter(Arc::new(task));
        let res = Resources::new();
        let err = adapter.run(&res).await.unwrap_err();
        assert!(err.to_string().contains("forward failed"));
        assert!(
            l.lock().unwrap().is_empty(),
            "compensate must not run on a forward failure"
        );
    }

    // The forward run is cut off at `attempt_timeout` (20 ms, not the 5 s sleep).
    #[tokio::test]
    async fn forward_run_respects_attempt_timeout() {
        let adapter = CompensatableAdapter(Arc::new(SlowTask));
        let res = Resources::new();
        let err = adapter.run(&res).await.unwrap_err();
        assert!(
            err.to_string().contains("exceeded attempt_timeout"),
            "got: {err}"
        );
    }

    // A split result is rejected after compensating inline; with a clean rollback
    // the error describes the split rejection.
    #[tokio::test]
    async fn split_result_runs_inline_compensate_then_returns_split_error() {
        let l = log();
        let mut task = Comp::ok(7, &l);
        task.forward_split = true;
        let adapter = CompensatableAdapter(Arc::new(task));
        let res = Resources::new();
        let err = adapter.run(&res).await.unwrap_err();
        assert!(
            err.to_string().contains("split result")
                || err
                    .to_string()
                    .contains("split states cannot be compensatable"),
            "got: {err}"
        );
        assert_eq!(*l.lock().unwrap(), vec![7]);
    }

    // If the inline compensation itself fails, both errors are reported together
    // (split error first, compensate error second).
    #[tokio::test]
    async fn split_with_failing_compensate_yields_compensation_failed() {
        let l = log();
        let mut task = Comp::ok(7, &l);
        task.forward_split = true;
        task.comp_fail = true;
        let adapter = CompensatableAdapter(Arc::new(task));
        let res = Resources::new();
        match adapter.run(&res).await.unwrap_err() {
            CanoError::CompensationFailed { errors } => {
                assert_eq!(errors.len(), 2, "original split error + compensate error");
                assert!(errors[1].to_string().contains("compensate failed"));
            }
            other => panic!("expected CompensationFailed, got {other:?}"),
        }
        assert_eq!(*l.lock().unwrap(), vec![7], "compensate was attempted");
    }

    // An output that cannot be serialized is compensated inline (log gets 999) and
    // the serialize failure is reported.
    #[tokio::test]
    async fn unserializable_output_runs_inline_compensate() {
        let l = log();
        let adapter = CompensatableAdapter(Arc::new(UnserTask {
            log: Arc::clone(&l),
        }));
        let res = Resources::new();
        let err = adapter.run(&res).await.unwrap_err();
        assert!(
            err.to_string().contains("serialize compensation output"),
            "got: {err}"
        );
        assert_eq!(*l.lock().unwrap(), vec![999]);
    }

    // A panic during inline compensation is caught and reported as the second
    // error of `CompensationFailed`.
    #[tokio::test]
    async fn inline_compensate_panic_becomes_compensation_failed() {
        let l = log();
        let mut task = Comp::ok(7, &l);
        task.forward_split = true;
        task.comp_panic = true;
        let adapter = CompensatableAdapter(Arc::new(task));
        let res = Resources::new();
        match adapter.run(&res).await.unwrap_err() {
            CanoError::CompensationFailed { errors } => {
                assert_eq!(errors.len(), 2);
                assert!(
                    errors[1].to_string().contains("panicked"),
                    "got: {}",
                    errors[1]
                );
            }
            other => panic!("expected CompensationFailed, got {other:?}"),
        }
        assert_eq!(*l.lock().unwrap(), vec![7]);
    }

    // Calls the helper directly: a clean rollback returns the original error
    // after running compensate with the given output.
    #[tokio::test]
    async fn run_inline_compensate_returns_original_error_on_clean_rollback() {
        let l = log();
        let task = Comp::ok(42, &l);
        let res = Resources::new();
        let original = CanoError::task_execution("the real failure");
        let out: Result<(TaskResult<S>, Vec<u8>), CanoError> =
            run_inline_compensate(&task, &res, 42, original).await;
        assert!(out.unwrap_err().to_string().contains("the real failure"));
        assert_eq!(*l.lock().unwrap(), vec![42]);
    }

    // A blob that is not valid JSON for `Output` errors out before the task's
    // `compensate` body runs.
    #[tokio::test]
    async fn compensate_with_corrupt_blob_errors() {
        let l = log();
        let adapter = CompensatableAdapter(Arc::new(Comp::ok(1, &l)));
        let res = Resources::new();
        let err = adapter
            .compensate(&res, b"not valid json")
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains("deserialize compensation output"),
            "got: {err}"
        );
        assert!(
            l.lock().unwrap().is_empty(),
            "compensate body must not run when the blob can't be deserialized"
        );
    }
}