mod compfn;
use std::fmt::{Debug, Display, Formatter};
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use crate::bindings::golem::api::host::{OplogIndex, get_oplog_index, set_oplog_index};
use crate::mark_atomic_operation;
pub use compfn::*;
/// A pinned, heap-allocated future that may borrow data for `'a`.
///
/// No `Send` bound is required, so the future may hold `Rc`s and other
/// single-threaded state.
pub type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

/// Moves `fut` to the heap, pins it and erases its concrete type.
///
/// Convenient for the closures handed to the transaction functions, which
/// must return a [`LocalBoxFuture`].
#[inline]
pub fn boxed<'a, T>(fut: impl Future<Output = T> + 'a) -> LocalBoxFuture<'a, T> {
    // `Pin<Box<F>>` unsizes to `Pin<Box<dyn Future + 'a>>` at the return site.
    Box::into_pin(Box::new(fut))
}
/// A reversible unit of work that can take part in a transaction.
///
/// `execute` performs the work for a given input. `compensate` undoes a
/// previous successful `execute`, receiving the same input and the output
/// that `execute` produced.
///
/// The operation, its input and its output are `Clone` because the
/// transaction types in this module clone the input (to use it both for
/// `execute` and later for `compensate`), clone the output (to return it and
/// also keep it for `compensate`), and store the operation so that its
/// compensation can run after the forward step has returned.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because
// callers cannot require the returned futures to be `Send`. This module works
// with non-`Send` futures anyway (`Rc`, `LocalBoxFuture`), so it is silenced.
#[allow(async_fn_in_trait)]
pub trait Operation: Clone {
/// Input passed to both `execute` and `compensate`.
type In: Clone;
/// Output of a successful `execute`, later handed to `compensate`.
type Out: Clone;
/// Error returned by either step.
type Err: Clone;
/// Performs the operation.
async fn execute(&self, input: Self::In) -> Result<Self::Out, Self::Err>;
/// Undoes a previous successful `execute` given the same `input` and its `result`.
async fn compensate(&self, input: Self::In, result: Self::Out) -> Result<(), Self::Err>;
}
/// Builds an [`Operation`] from two asynchronous closures.
///
/// `execute_fn` performs the step. `compensate_fn` undoes it, given the
/// original input and the output produced by `execute_fn`. Both closures return
/// boxed futures. They are stored behind `Rc`, so cloning the resulting
/// operation only bumps reference counts.
#[allow(clippy::type_complexity)]
pub fn operation<In: Clone, Out: Clone, Err: Clone>(
    execute_fn: impl Fn(In) -> Pin<Box<dyn Future<Output = Result<Out, Err>>>> + 'static,
    compensate_fn: impl Fn(In, Out) -> Pin<Box<dyn Future<Output = Result<(), Err>>>> + 'static,
) -> impl Operation<In = In, Out = Out, Err = Err> {
    let execute_fn = Rc::new(execute_fn);
    let compensate_fn = Rc::new(compensate_fn);
    // Each `Rc<closure>` unsizes to the `Rc<dyn Fn ...>` field type here.
    FnOperation::<In, Out, Err> {
        execute_fn,
        compensate_fn,
    }
}
/// Builds an [`Operation`] from two synchronous closures.
///
/// `execute_fn` performs the step. `compensate_fn` undoes it, given the
/// original input and the output of `execute_fn`. Each closure runs
/// synchronously as soon as the resulting operation invokes it, and its result
/// is returned through an already-completed future.
pub fn sync_operation<In: Clone + 'static, Out: Clone + 'static, Err: Clone + 'static>(
    execute_fn: impl Fn(In) -> Result<Out, Err> + 'static,
    compensate_fn: impl Fn(In, Out) -> Result<(), Err> + 'static,
) -> impl Operation<In = In, Out = Out, Err = Err> {
    // The closures are called eagerly, so the returned futures own only the
    // result and never borrow `execute_fn`/`compensate_fn`. This removes the
    // need to wrap them in `Rc` and clone the `Rc` on every call.
    operation(
        move |input| {
            let result = execute_fn(input);
            Box::pin(async move { result })
        },
        move |input, output| {
            let result = compensate_fn(input, output);
            Box::pin(async move { result })
        },
    )
}
/// An [`Operation`] backed by a pair of reference-counted closures.
#[allow(clippy::type_complexity)]
struct FnOperation<In, Out, Err> {
    /// Produces the future performing the forward step.
    execute_fn: Rc<dyn Fn(In) -> Pin<Box<dyn Future<Output = Result<Out, Err>>>>>,
    /// Produces the future undoing a successful forward step.
    compensate_fn: Rc<dyn Fn(In, Out) -> Pin<Box<dyn Future<Output = Result<(), Err>>>>>,
}

// Written by hand: `#[derive(Clone)]` would add `In: Clone`, `Out: Clone` and
// `Err: Clone` bounds, although only the `Rc`s need to be cloned.
impl<In, Out, Err> Clone for FnOperation<In, Out, Err> {
    fn clone(&self) -> Self {
        let execute_fn = Rc::clone(&self.execute_fn);
        let compensate_fn = Rc::clone(&self.compensate_fn);
        Self {
            execute_fn,
            compensate_fn,
        }
    }
}

impl<In: Clone, Out: Clone, Err: Clone> Operation for FnOperation<In, Out, Err> {
    type In = In;
    type Out = Out;
    type Err = Err;

    /// Calls the stored execute closure and awaits the future it returns.
    async fn execute(&self, input: In) -> Result<Out, Err> {
        let pending = (self.execute_fn)(input);
        pending.await
    }

    /// Calls the stored compensate closure and awaits the future it returns.
    async fn compensate(&self, input: In, result: Out) -> Result<(), Err> {
        let pending = (self.compensate_fn)(input, result);
        pending.await
    }
}
/// Outcome of a transaction: the body's output, or a [`TransactionFailure`].
pub type TransactionResult<Out, Err> = Result<Out, TransactionFailure<Err>>;

/// Describes how a failed fallible transaction was rolled back.
///
/// `Clone`, `PartialEq` and `Eq` are derived (bounded on `Err`) so callers can
/// copy failures and compare them, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionFailure<Err> {
    /// The body failed with this error and every recorded compensation succeeded.
    FailedAndRolledBackCompletely(Err),
    /// The body failed and the rollback stopped at a compensation that failed.
    FailedAndRolledBackPartially {
        /// The error that made the transaction fail.
        failure: Err,
        /// The error returned by the compensation that could not be applied.
        compensation_failure: Err,
    },
}

impl<Err: Display> Display for TransactionFailure<Err> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::FailedAndRolledBackCompletely(err) => {
                write!(
                    f,
                    "Transaction failed with {err} and rolled back completely."
                )
            }
            Self::FailedAndRolledBackPartially {
                failure,
                compensation_failure,
            } => write!(
                f,
                "Transaction failed with {failure} and rolled back partially; compensation failed with: {compensation_failure}."
            ),
        }
    }
}
/// Runs `f` inside a new [`FallibleTransaction`].
///
/// If `f` returns `Ok`, its value is passed through, and the recorded
/// compensations are dropped without being run. If `f` returns `Err`, the
/// compensations are run newest-first. The outcome is then reported as a
/// [`TransactionFailure`].
pub async fn fallible_transaction<Out, Err>(
    f: impl for<'a> FnOnce(&'a mut FallibleTransaction<Err>) -> LocalBoxFuture<'a, Result<Out, Err>>,
) -> TransactionResult<Out, Err>
where
    Err: Clone + 'static,
{
    let mut tx = FallibleTransaction::new();
    let cause = match f(&mut tx).await {
        Ok(value) => return Ok(value),
        Err(cause) => cause,
    };
    Err(tx.on_fail(cause).await)
}
/// Runs `f` inside a new [`InfallibleTransaction`].
///
/// The current oplog index is captured first and handed to the transaction,
/// which rewinds to it (`set_oplog_index`) when an operation fails. The whole
/// body runs while the value returned by `mark_atomic_operation` is alive.
/// NOTE(review): that value presumably closes the atomic region on drop; confirm.
pub async fn infallible_transaction<Out>(
    f: impl for<'a> FnOnce(&'a mut InfallibleTransaction) -> LocalBoxFuture<'a, Out>,
) -> Out {
    // The index is read before the atomic region is opened.
    let begin_index = get_oplog_index();
    // Bound to a named variable (not `_`) so it lives until `f` has completed.
    let _atomic_guard = mark_atomic_operation();
    let mut tx = InfallibleTransaction::new(begin_index);
    f(&mut tx).await
}
/// Placeholder for an infallible transaction with strong rollback guarantees.
///
/// # Panics
///
/// Always panics once the returned future is polled, because this variant is
/// not implemented yet. The message names the function, so the failure is
/// easy to trace.
pub async fn infallible_transaction_with_strong_rollback_guarantees<Out>(
    _f: impl for<'a> FnOnce(&'a mut InfallibleTransaction) -> LocalBoxFuture<'a, Out>,
) -> Out {
    unimplemented!("infallible_transaction_with_strong_rollback_guarantees is not implemented yet")
}
/// Runs `f` inside a transaction whose strategy is chosen by the type `T`.
///
/// This allows the same body to be written once and run under either
/// [`FallibleTransaction`] or [`InfallibleTransaction`].
pub async fn transaction<Out: 'static, Err, T>(
    f: impl for<'a> FnOnce(&'a mut T) -> LocalBoxFuture<'a, Result<Out, Err>>,
) -> TransactionResult<Out, Err>
where
    T: Transaction<Err>,
{
    <T as Transaction<Err>>::run(f).await
}
/// A deferred, single-use undo step recorded by a transaction.
///
/// The boxed closure is `FnOnce`. Calling it builds the future that performs
/// the compensation.
#[allow(clippy::type_complexity)]
struct CompensationAction<Err> {
    action: Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<(), Err>>>>>,
}

impl<Err> CompensationAction<Err> {
    /// Consumes the action, builds its compensation future and awaits it.
    pub async fn execute(self) -> Result<(), Err> {
        let Self { action } = self;
        let compensation = action();
        compensation.await
    }
}
/// A transaction that reports failures back to the caller.
///
/// Every successfully executed operation records a compensation. If the
/// transaction body fails, the recorded compensations are run newest-first
/// (see `on_fail`). The result is reported as a [`TransactionFailure`].
pub struct FallibleTransaction<Err> {
    /// Undo steps for the operations that succeeded so far, in execution order.
    compensations: Vec<CompensationAction<Err>>,
}

impl<Err: Clone + 'static> FallibleTransaction<Err> {
    fn new() -> Self {
        Self {
            compensations: Vec::new(),
        }
    }

    /// Executes `operation` with `input` and, if it succeeds, records its compensation.
    ///
    /// # Errors
    ///
    /// Returns the operation's own error unchanged. Nothing is recorded in that case.
    pub async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
        &mut self,
        operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
        input: OpIn,
    ) -> Result<OpOut, Err> {
        // `input` is needed twice: a copy for `execute` now, and the original
        // for a possible `compensate` later.
        let output = operation.execute(input.clone()).await?;
        // One copy of the output goes back to the caller; the other is kept for compensation.
        let compensation_output = output.clone();
        // `operation`, `input` and the output copy are owned and used only here,
        // so they are moved into the one-shot action instead of being cloned again.
        self.compensations.push(CompensationAction {
            action: Box::new(move || {
                Box::pin(async move { operation.compensate(input, compensation_output).await })
            }),
        });
        Ok(output)
    }

    /// Rolls back the recorded operations newest-first after `failure`.
    ///
    /// Stops at the first compensation that fails. The remaining compensations
    /// are dropped without being run, and the result is
    /// `FailedAndRolledBackPartially`.
    async fn on_fail(&mut self, failure: Err) -> TransactionFailure<Err> {
        for compensation_action in self.compensations.drain(..).rev() {
            if let Err(compensation_failure) = compensation_action.execute().await {
                return TransactionFailure::FailedAndRolledBackPartially {
                    failure,
                    compensation_failure,
                };
            }
        }
        TransactionFailure::FailedAndRolledBackCompletely(failure)
    }
}
/// A transaction whose operations never fail from the caller's point of view.
///
/// When an operation fails, every previously recorded compensation is run and
/// the oplog index is reset to the one captured at the start (see `retry`).
pub struct InfallibleTransaction {
    /// Oplog index captured when the transaction began.
    begin_oplog_index: OplogIndex,
    /// Undo steps for the successful operations, in execution order. Each one
    /// panics if its underlying compensation fails, so it always yields `Ok(())`.
    compensations: Vec<CompensationAction<()>>,
}

impl InfallibleTransaction {
    fn new(begin_oplog_index: OplogIndex) -> Self {
        Self {
            begin_oplog_index,
            compensations: Vec::new(),
        }
    }

    /// Executes `operation` with `input` and records its compensation on success.
    ///
    /// On failure, `retry` is invoked, and this method treats that call as never
    /// returning. NOTE(review): presumably `set_oplog_index` makes the host
    /// replay from the start of the transaction; confirm.
    ///
    /// # Panics
    ///
    /// - If `retry` does return after a failed operation. The message includes
    ///   the operation's error.
    /// - When a recorded compensation is later run and fails
    ///   ("Compensation action failed").
    pub async fn execute<
        OpIn: Clone + 'static,
        OpOut: Clone + 'static,
        OpErr: Debug + Clone + 'static,
    >(
        &mut self,
        operation: impl Operation<In = OpIn, Out = OpOut, Err = OpErr> + 'static,
        input: OpIn,
    ) -> OpOut {
        // Bind the result first so the borrow of `operation` ends before it is moved below.
        let result = operation.execute(input.clone()).await;
        let output = match result {
            Ok(output) => output,
            Err(error) => {
                self.retry().await;
                unreachable!(
                    "InfallibleTransaction::retry returned after a failed operation: {:?}",
                    error
                )
            }
        };
        let compensation_output = output.clone();
        // Owned values are moved into the one-shot action; no further clones needed.
        self.compensations.push(CompensationAction {
            action: Box::new(move || {
                Box::pin(async move {
                    operation
                        .compensate(input, compensation_output)
                        .await
                        .expect("Compensation action failed");
                    Ok(())
                })
            }),
        });
        output
    }

    /// Runs all recorded compensations newest-first, then resets the oplog
    /// index to the one captured when the transaction began.
    pub async fn retry(&mut self) {
        for compensation_action in self.compensations.drain(..).rev() {
            // Infallible compensations panic instead of returning `Err`, so
            // the result carries no information.
            let _ = compensation_action.execute().await;
        }
        set_oplog_index(self.begin_oplog_index);
    }
}
/// Common interface over the transaction strategies in this module.
///
/// Code written against it can run under either [`FallibleTransaction`] or
/// [`InfallibleTransaction`]; see [`transaction`].
// `async_fn_in_trait` is silenced for the same reason as on `Operation`:
// the futures are not required to be `Send` anywhere in this module.
#[allow(async_fn_in_trait)]
pub trait Transaction<Err> {
/// Executes `operation` with `input`, recording its compensation on success.
///
/// The `InfallibleTransaction` implementation always returns `Ok` here.
async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
&mut self,
operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
input: OpIn,
) -> Result<OpOut, Err>;
/// Aborts the transaction with `error`.
///
/// The fallible implementation just returns `Err(error)`, so the body can
/// propagate it with `?`. The infallible implementation first calls `retry`,
/// which runs the recorded compensations and resets the oplog index.
async fn fail(&mut self, error: Err) -> Result<(), Err>;
/// Creates a fresh transaction of this kind, runs `f` in it and reports the outcome.
async fn run<Out: 'static>(
f: impl for<'a> FnOnce(&'a mut Self) -> LocalBoxFuture<'a, Result<Out, Err>>,
) -> TransactionResult<Out, Err>;
}
/// [`Transaction`] implementation that surfaces errors to the caller.
impl<Err: Clone + 'static> Transaction<Err> for FallibleTransaction<Err> {
    /// Delegates to the inherent `FallibleTransaction::execute`.
    async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
        &mut self,
        operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
        input: OpIn,
    ) -> Result<OpOut, Err> {
        // Inherent items take precedence in this path, so this calls the
        // inherent method rather than this trait method.
        <FallibleTransaction<Err>>::execute(self, operation, input).await
    }

    /// Returns the error unchanged. When the body propagates it,
    /// `fallible_transaction` performs the rollback.
    async fn fail(&mut self, reason: Err) -> Result<(), Err> {
        Err(reason)
    }

    /// Runs `f` through [`fallible_transaction`].
    async fn run<Out: 'static>(
        f: impl for<'a> FnOnce(&'a mut Self) -> LocalBoxFuture<'a, Result<Out, Err>>,
    ) -> TransactionResult<Out, Err> {
        fallible_transaction(f).await
    }
}
/// [`Transaction`] implementation in which failures trigger a rollback and
/// retry instead of being reported.
impl<Err: Debug + Clone + 'static> Transaction<Err> for InfallibleTransaction {
    /// Delegates to the inherent `InfallibleTransaction::execute`, which never
    /// returns an error (it calls `retry` instead), so the result is always `Ok`.
    async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
        &mut self,
        operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
        input: OpIn,
    ) -> Result<OpOut, Err> {
        Ok(InfallibleTransaction::execute(self, operation, input).await)
    }

    /// Runs all recorded compensations and resets the oplog index, then
    /// returns `error`.
    async fn fail(&mut self, error: Err) -> Result<(), Err> {
        InfallibleTransaction::retry(self).await;
        Err(error)
    }

    /// Runs `f` through [`infallible_transaction`].
    ///
    /// # Panics
    ///
    /// Panics if the body completes with `Err`. Inside an infallible
    /// transaction, failures are expected to be handled by `retry` rather than
    /// returned.
    async fn run<Out: 'static>(
        f: impl for<'a> FnOnce(&'a mut Self) -> LocalBoxFuture<'a, Result<Out, Err>>,
    ) -> TransactionResult<Out, Err> {
        Ok(infallible_transaction(|tx| -> LocalBoxFuture<'_, Out> {
            let fut = f(tx);
            Box::pin(async {
                fut.await
                    .expect("infallible transaction body returned an error")
            })
        })
        .await)
    }
}
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::rc::Rc;
    use test_r::test;
    use super::TransactionFailure;
    use crate::{boxed, fallible_transaction, infallible_transaction, sync_operation};

    // A failing second operation must roll back the first one. The failure must
    // be reported as a complete rollback carrying the original error.
    #[test]
    #[ignore]
    async fn tx_test_1() {
        let log = Rc::new(RefCell::new(Vec::new()));
        let log1 = log.clone();
        let log2 = log.clone();
        let log3 = log.clone();
        let log4 = log.clone();
        let op1 = sync_operation(
            move |input: String| {
                log1.borrow_mut().push(format!("op1 execute {input}"));
                Ok(())
            },
            move |input: String, _| {
                log2.borrow_mut().push(format!("op1 rollback {input}"));
                Ok(())
            },
        );
        let op2 = sync_operation(
            move |_: ()| {
                log3.borrow_mut().push("op2 execute".to_string());
                Err::<(), &str>("op2 error")
            },
            move |_: (), _| {
                log4.borrow_mut().push("op2 rollback".to_string());
                Ok(())
            },
        );
        let result = fallible_transaction(|tx| {
            boxed(async move {
                println!("First we execute op1");
                tx.execute(op1, "hello".to_string()).await?;
                println!("Then execute op2");
                tx.execute(op2, ()).await?;
                println!("Finally compute a result");
                Ok(11)
            })
        })
        .await;
        println!("{log:?}");
        println!("{result:?}");

        // op2 failed, so it registered no compensation; only op1 is rolled back.
        assert_eq!(
            *log.borrow(),
            vec!["op1 execute hello", "op2 execute", "op1 rollback hello"]
        );
        assert!(matches!(
            result,
            Err(TransactionFailure::FailedAndRolledBackCompletely("op2 error"))
        ));
    }

    // Exercises the infallible path. A failing op2 triggers `retry`, which calls
    // `set_oplog_index`; NOTE(review): presumably this needs a Golem host to run,
    // hence `#[ignore]` and print-only output. Confirm.
    #[test]
    #[ignore]
    async fn tx_test_2() {
        let log = Rc::new(RefCell::new(Vec::new()));
        let log1 = log.clone();
        let log2 = log.clone();
        let log3 = log.clone();
        let log4 = log.clone();
        let op1 = sync_operation(
            move |input: String| {
                log1.borrow_mut().push(format!("op1 execute {input}"));
                Ok::<(), ()>(())
            },
            move |input: String, _| {
                log2.borrow_mut().push(format!("op1 rollback {input}"));
                Ok(())
            },
        );
        let op2 = sync_operation(
            move |_: ()| {
                log3.borrow_mut().push("op2 execute".to_string());
                Err::<(), &str>("op2 error")
            },
            move |_: (), r| {
                log4.borrow_mut().push(format!("op2 rollback {r:?}"));
                Ok(())
            },
        );
        let result = infallible_transaction(|tx| {
            boxed(async move {
                println!("First we execute op1");
                tx.execute(op1, "hello".to_string()).await;
                println!("Then execute op2");
                tx.execute(op2, ()).await;
                println!("Finally compute a result");
                11
            })
        })
        .await;
        println!("{log:?}");
        println!("{result:?}");
    }
}
// Tests for the `#[golem_operation]` attribute macro, built only with the
// `macro` feature. They are `#[ignore]`d and only print, so in practice they
// mostly exercise that the macro-generated code compiles.
#[cfg(test)]
#[cfg(feature = "macro")]
mod macro_tests {
use crate::{boxed, fallible_transaction, infallible_transaction};
use golem_rust_macro::golem_operation;
use test_r::test;
// Makes this crate reachable as `golem_rust` from inside the module.
// NOTE(review): presumably needed because the macro emits `golem_rust::...`
// paths; confirm against golem_rust_macro.
mod golem_rust {
pub use crate::*;
}
// Variant 1: the compensation takes the result (`bool`) followed by both inputs.
#[golem_operation(compensation=test_compensation)]
fn test_operation(input1: u64, input2: f32) -> Result<bool, String> {
println!("Op input: {input1}, {input2}");
Ok(true)
}
fn test_compensation(_: bool, input1: u64, input2: f32) -> Result<(), String> {
println!("Compensation input: {input1}, {input2}");
Ok(())
}
// Variant 2: the compensation takes only the operation's result.
#[golem_operation(compensation=test_compensation_2)]
fn test_operation_2(input1: u64, input2: f32) -> Result<bool, String> {
println!("Op input: {input1}, {input2}");
Ok(true)
}
fn test_compensation_2(result: bool) -> Result<(), String> {
println!("Compensation for operation result {result:?}");
Ok(())
}
// Variant 3: the compensation takes no arguments at all.
#[golem_operation(compensation=test_compensation_3)]
fn test_operation_3(input: String) -> Result<(), String> {
println!("Op input: {input}");
Ok(())
}
fn test_compensation_3() -> Result<(), String> {
println!("Compensation for operation, not using any input");
Ok(())
}
// Variant 4: unit result plus a single input.
#[golem_operation(compensation=test_compensation_4)]
fn test_operation_4(input: u64) -> Result<(), String> {
println!("Op input: {input}");
Ok(())
}
fn test_compensation_4(_: (), input: u64) -> Result<(), String> {
println!("Compensation for operation with single input {input}");
Ok(())
}
// Calls each annotated operation as a method on a fallible transaction.
// NOTE(review): these methods are presumably generated by `#[golem_operation]`.
#[test]
#[ignore]
async fn tx_test_1() {
let result = fallible_transaction(|tx| {
boxed(async move {
println!("Executing the annotated function as an operation directly");
tx.test_operation(1, 0.1).await?;
tx.test_operation_2(1, 0.1).await?;
tx.test_operation_3("test".to_string()).await?;
tx.test_operation_4(1).await?;
Ok(11)
})
})
.await;
println!("{result:?}");
}
// Same as above, but inside an infallible transaction; the returned result is ignored.
#[test]
#[ignore]
async fn tx_test_2() {
let result = infallible_transaction(|tx| {
boxed(async move {
println!("Executing the annotated function as an operation directly");
let _ = tx.test_operation(1, 0.1).await;
11
})
})
.await;
println!("{result:?}");
}
}