1mod compfn;
16
17use std::fmt::{Debug, Display, Formatter};
18use std::future::Future;
19use std::pin::Pin;
20use std::rc::Rc;
21
22use crate::bindings::golem::api::host::{OplogIndex, get_oplog_index, set_oplog_index};
23use crate::mark_atomic_operation;
24
25pub use compfn::*;
26
/// A heap-allocated, pinned, type-erased future that yields `T` and may borrow
/// data for the lifetime `'a`.
///
/// The alias carries no `Send` bound, so the boxed future may hold non-`Send`
/// state such as `Rc`.
pub type LocalBoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;
32
33#[inline]
43pub fn boxed<'a, T>(fut: impl Future<Output = T> + 'a) -> LocalBoxFuture<'a, T> {
44 Box::pin(fut)
45}
46
/// A unit of work that can be executed and later undone by a compensation step.
///
/// Implementations must be `Clone`, and so must the input, output and error
/// types: the transaction types in this module clone them when recording the
/// compensation for a successful step.
#[allow(async_fn_in_trait)]
pub trait Operation: Clone {
    /// Input passed to both [`Operation::execute`] and [`Operation::compensate`].
    type In: Clone;
    /// Value produced by a successful [`Operation::execute`].
    type Out: Clone;
    /// Error type shared by execution and compensation.
    type Err: Clone;

    /// Runs the operation with `input`.
    async fn execute(&self, input: Self::In) -> Result<Self::Out, Self::Err>;

    /// Undoes a previously successful execution.
    ///
    /// Receives the same `input` that was given to `execute` together with the
    /// `result` that `execute` returned.
    async fn compensate(&self, input: Self::In, result: Self::Out) -> Result<(), Self::Err>;
}
63
64#[allow(clippy::type_complexity)]
70pub fn operation<In: Clone, Out: Clone, Err: Clone>(
71 execute_fn: impl Fn(In) -> Pin<Box<dyn Future<Output = Result<Out, Err>>>> + 'static,
72 compensate_fn: impl Fn(In, Out) -> Pin<Box<dyn Future<Output = Result<(), Err>>>> + 'static,
73) -> impl Operation<In = In, Out = Out, Err = Err> {
74 FnOperation {
75 execute_fn: Rc::new(execute_fn),
76 compensate_fn: Rc::new(compensate_fn),
77 }
78}
79
80pub fn sync_operation<In: Clone + 'static, Out: Clone + 'static, Err: Clone + 'static>(
84 execute_fn: impl Fn(In) -> Result<Out, Err> + 'static,
85 compensate_fn: impl Fn(In, Out) -> Result<(), Err> + 'static,
86) -> impl Operation<In = In, Out = Out, Err = Err> {
87 let execute_fn = Rc::new(execute_fn);
88 let compensate_fn = Rc::new(compensate_fn);
89 operation(
90 move |input| {
91 let f = execute_fn.clone();
92 Box::pin(async move { f(input) })
93 },
94 move |input, output| {
95 let f = compensate_fn.clone();
96 Box::pin(async move { f(input, output) })
97 },
98 )
99}
100
/// Closure-backed [`Operation`] implementation returned by [`operation`].
///
/// Both closures are reference counted so the operation can be cloned without
/// requiring the closures themselves to be `Clone`.
#[allow(clippy::type_complexity)]
struct FnOperation<In, Out, Err> {
    /// Produces the future that performs the operation for a given input.
    execute_fn: Rc<dyn Fn(In) -> Pin<Box<dyn Future<Output = Result<Out, Err>>>>>,
    /// Produces the future that undoes the operation, given its input and output.
    compensate_fn: Rc<dyn Fn(In, Out) -> Pin<Box<dyn Future<Output = Result<(), Err>>>>>,
}
106
107impl<In, Out, Err> Clone for FnOperation<In, Out, Err> {
108 fn clone(&self) -> Self {
109 Self {
110 execute_fn: self.execute_fn.clone(),
111 compensate_fn: self.compensate_fn.clone(),
112 }
113 }
114}
115
116impl<In: Clone, Out: Clone, Err: Clone> Operation for FnOperation<In, Out, Err> {
117 type In = In;
118 type Out = Out;
119 type Err = Err;
120
121 async fn execute(&self, input: In) -> Result<Out, Err> {
122 (self.execute_fn)(input).await
123 }
124
125 async fn compensate(&self, input: In, result: Out) -> Result<(), Err> {
126 (self.compensate_fn)(input, result).await
127 }
128}
129
/// Result of running a transaction: either the transaction body's output or a
/// [`TransactionFailure`] carrying the error(s) that ended it.
pub type TransactionResult<Out, Err> = Result<Out, TransactionFailure<Err>>;
132
/// Describes how a failed transaction ended.
///
/// Besides `Debug`, the type derives `Clone`, `PartialEq` and `Eq` (each only
/// available when `Err` implements the same trait) so that a
/// `TransactionResult` can be compared in assertions and stored or forwarded
/// without being consumed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransactionFailure<Err> {
    /// The transaction failed with the contained error and every registered
    /// compensation action completed successfully.
    FailedAndRolledBackCompletely(Err),
    /// The transaction failed and one of the compensation actions failed too,
    /// so the rollback did not complete.
    FailedAndRolledBackPartially {
        /// The error that caused the transaction to fail.
        failure: Err,
        /// The error returned by the compensation action that failed.
        compensation_failure: Err,
    },
}
145
146impl<Err: Display> Display for TransactionFailure<Err> {
147 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148 match self {
149 TransactionFailure::FailedAndRolledBackCompletely(err) => {
150 write!(
151 f,
152 "Transaction failed with {err} and rolled back completely."
153 )
154 }
155 TransactionFailure::FailedAndRolledBackPartially {
156 failure,
157 compensation_failure,
158 } => write!(
159 f,
160 "Transaction failed with {failure} and rolled back partially; compensation failed with: {compensation_failure}."
161 ),
162 }
163 }
164}
165
166pub async fn fallible_transaction<Out, Err>(
170 f: impl for<'a> FnOnce(&'a mut FallibleTransaction<Err>) -> LocalBoxFuture<'a, Result<Out, Err>>,
171) -> TransactionResult<Out, Err>
172where
173 Err: Clone + 'static,
174{
175 let mut transaction = FallibleTransaction::new();
176 match f(&mut transaction).await {
177 Ok(output) => Ok(output),
178 Err(error) => Err(transaction.on_fail(error).await),
179 }
180}
181
182pub async fn infallible_transaction<Out>(
186 f: impl for<'a> FnOnce(&'a mut InfallibleTransaction) -> LocalBoxFuture<'a, Out>,
187) -> Out {
188 let oplog_index = get_oplog_index();
189 let _atomic_region = mark_atomic_operation();
190 let mut transaction = InfallibleTransaction::new(oplog_index);
191 f(&mut transaction).await
192}
193
/// Placeholder for an infallible transaction variant with stronger rollback
/// guarantees; it has the same signature as [`infallible_transaction`].
///
/// # Panics
///
/// Always panics when awaited: the body is `unimplemented!()` and `_f` is
/// never called.
pub async fn infallible_transaction_with_strong_rollback_guarantees<Out>(
    _f: impl for<'a> FnOnce(&'a mut InfallibleTransaction) -> LocalBoxFuture<'a, Out>,
) -> Out {
    unimplemented!()
}
202
203pub async fn transaction<Out: 'static, Err, T>(
210 f: impl for<'a> FnOnce(&'a mut T) -> LocalBoxFuture<'a, Result<Out, Err>>,
211) -> TransactionResult<Out, Err>
212where
213 T: Transaction<Err>,
214{
215 T::run(f).await
216}
217
/// A deferred, run-at-most-once rollback step recorded by a transaction.
#[allow(clippy::type_complexity)]
struct CompensationAction<Err> {
    /// Closure that, when called, produces the future performing the rollback.
    action: Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<(), Err>>>>>,
}
223
224impl<Err> CompensationAction<Err> {
225 pub async fn execute(self) -> Result<(), Err> {
226 (self.action)().await
227 }
228}
229
/// Transaction state for [`fallible_transaction`].
///
/// Records one compensation action per successfully executed operation so
/// they can be run in reverse order if the transaction fails.
pub struct FallibleTransaction<Err> {
    /// Compensation actions in the order their operations succeeded.
    compensations: Vec<CompensationAction<Err>>,
}
239
240impl<Err: Clone + 'static> FallibleTransaction<Err> {
241 fn new() -> Self {
242 Self {
243 compensations: Vec::new(),
244 }
245 }
246
247 pub async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
248 &mut self,
249 operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
250 input: OpIn,
251 ) -> Result<OpOut, Err> {
252 let result = operation.execute(input.clone()).await;
253 if let Ok(output) = &result {
254 let cloned_op = operation.clone();
255 let cloned_out = output.clone();
256 self.compensations.push(CompensationAction {
257 action: Box::new(move || {
258 Box::pin(async move {
259 cloned_op
260 .compensate(input.clone(), cloned_out.clone())
261 .await
262 })
263 }),
264 });
265 }
266 result
267 }
268
269 async fn on_fail(&mut self, failure: Err) -> TransactionFailure<Err> {
270 for compensation_action in self.compensations.drain(..).rev() {
271 if let Err(compensation_failure) = compensation_action.execute().await {
272 return TransactionFailure::FailedAndRolledBackPartially {
273 failure,
274 compensation_failure,
275 };
276 }
277 }
278 TransactionFailure::FailedAndRolledBackCompletely(failure)
279 }
280}
281
/// Transaction state for [`infallible_transaction`].
///
/// Instead of surfacing operation errors to the caller, this transaction runs
/// the recorded compensations and then resets the oplog index to the value
/// captured when the transaction started.
pub struct InfallibleTransaction {
    /// Oplog index captured when the transaction began; passed to
    /// `set_oplog_index` on retry.
    begin_oplog_index: OplogIndex,
    /// Compensation actions in the order their operations succeeded.
    compensations: Vec<CompensationAction<()>>,
}
296
297impl InfallibleTransaction {
298 fn new(begin_oplog_index: OplogIndex) -> Self {
299 Self {
300 begin_oplog_index,
301 compensations: Vec::new(),
302 }
303 }
304
305 pub async fn execute<
306 OpIn: Clone + 'static,
307 OpOut: Clone + 'static,
308 OpErr: Debug + Clone + 'static,
309 >(
310 &mut self,
311 operation: impl Operation<In = OpIn, Out = OpOut, Err = OpErr> + 'static,
312 input: OpIn,
313 ) -> OpOut {
314 match operation.execute(input.clone()).await {
315 Ok(output) => {
316 let cloned_op = operation.clone();
317 let cloned_out = output.clone();
318 self.compensations.push(CompensationAction {
319 action: Box::new(move || {
320 Box::pin(async move {
321 cloned_op
322 .compensate(input.clone(), cloned_out.clone())
323 .await
324 .expect("Compensation action failed");
325 Ok(())
326 })
327 }),
328 });
329 output
330 }
331 Err(_) => {
332 self.retry().await;
333 unreachable!()
334 }
335 }
336 }
337
338 pub async fn retry(&mut self) {
340 for compensation_action in self.compensations.drain(..).rev() {
341 let _ = compensation_action.execute().await;
342 }
343 set_oplog_index(self.begin_oplog_index);
344 }
345}
346
/// Common interface over the transaction flavours in this module, allowing
/// transaction bodies to be written generically and run via [`transaction`].
#[allow(async_fn_in_trait)]
pub trait Transaction<Err> {
    /// Executes `operation` with `input` inside the transaction.
    async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
        &mut self,
        operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
        input: OpIn,
    ) -> Result<OpOut, Err>;

    /// Signals a failure of the transaction with `error`.
    async fn fail(&mut self, error: Err) -> Result<(), Err>;

    /// Creates a transaction of this type, runs the body `f` against it and
    /// reports the outcome as a [`TransactionResult`].
    async fn run<Out: 'static>(
        f: impl for<'a> FnOnce(&'a mut Self) -> LocalBoxFuture<'a, Result<Out, Err>>,
    ) -> TransactionResult<Out, Err>;
}
364
/// [`Transaction`] implementation that forwards to the inherent
/// [`FallibleTransaction`] API.
impl<Err: Clone + 'static> Transaction<Err> for FallibleTransaction<Err> {
    /// Forwards to the inherent `FallibleTransaction::execute`.
    async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
        &mut self,
        operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
        input: OpIn,
    ) -> Result<OpOut, Err> {
        // The type-qualified path selects the inherent method, not this trait
        // method, so the call does not recurse.
        FallibleTransaction::execute(self, operation, input).await
    }

    /// Returns `error` as `Err` without touching the recorded compensations;
    /// the caller is expected to propagate it out of the transaction body.
    async fn fail(&mut self, error: Err) -> Result<(), Err> {
        Err(error)
    }

    /// Runs `f` through [`fallible_transaction`].
    async fn run<Out: 'static>(
        f: impl for<'a> FnOnce(&'a mut Self) -> LocalBoxFuture<'a, Result<Out, Err>>,
    ) -> TransactionResult<Out, Err> {
        fallible_transaction(f).await
    }
}
384
/// [`Transaction`] implementation that forwards to the inherent
/// [`InfallibleTransaction`] API.
impl<Err: Debug + Clone + 'static> Transaction<Err> for InfallibleTransaction {
    /// Forwards to the inherent `InfallibleTransaction::execute` and wraps its
    /// output in `Ok`; this method never returns `Err` itself.
    async fn execute<OpIn: Clone + 'static, OpOut: Clone + 'static>(
        &mut self,
        operation: impl Operation<In = OpIn, Out = OpOut, Err = Err> + 'static,
        input: OpIn,
    ) -> Result<OpOut, Err> {
        // The type-qualified path selects the inherent method, not this trait
        // method, so the call does not recurse.
        Ok(InfallibleTransaction::execute(self, operation, input).await)
    }

    /// Triggers `InfallibleTransaction::retry` and, if that returns, yields
    /// `error` as `Err`.
    async fn fail(&mut self, error: Err) -> Result<(), Err> {
        InfallibleTransaction::retry(self).await;
        Err(error)
    }

    /// Runs `f` through [`infallible_transaction`] and always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if the body `f` resolves to `Err`, because the result is
    /// unwrapped.
    async fn run<Out: 'static>(
        f: impl for<'a> FnOnce(&'a mut Self) -> LocalBoxFuture<'a, Result<Out, Err>>,
    ) -> TransactionResult<Out, Err> {
        Ok(infallible_transaction(|tx| -> LocalBoxFuture<'_, Out> {
            let fut = f(tx);
            // Adapt the `Result`-returning body to the plain `Out` expected by
            // `infallible_transaction`.
            Box::pin(async { fut.await.unwrap() })
        })
        .await)
    }
}
409
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::rc::Rc;
    use test_r::test;

    use crate::{boxed, fallible_transaction, infallible_transaction, sync_operation};

    // Ignored by default: exercises the fallible transaction API with one
    // succeeding and one failing operation and prints the recorded log and the
    // result instead of asserting on them.
    #[test]
    #[ignore]
    async fn tx_test_1() {
        let log = Rc::new(RefCell::new(Vec::new()));

        let op1_execute_log = Rc::clone(&log);
        let op1_rollback_log = Rc::clone(&log);
        let op2_execute_log = Rc::clone(&log);
        let op2_rollback_log = Rc::clone(&log);

        let op1 = sync_operation(
            move |input: String| {
                op1_execute_log
                    .borrow_mut()
                    .push(format!("op1 execute {input}"));
                Ok(())
            },
            move |input: String, _| {
                op1_rollback_log
                    .borrow_mut()
                    .push(format!("op1 rollback {input}"));
                Ok(())
            },
        );

        let op2 = sync_operation(
            move |_: ()| {
                op2_execute_log.borrow_mut().push("op2 execute".to_string());
                Err::<(), &str>("op2 error")
            },
            move |_: (), _| {
                op2_rollback_log
                    .borrow_mut()
                    .push("op2 rollback".to_string());
                Ok(())
            },
        );

        let result = fallible_transaction(|tx| {
            boxed(async move {
                println!("First we execute op1");
                tx.execute(op1, "hello".to_string()).await?;
                println!("Then execute op2");
                tx.execute(op2, ()).await?;
                println!("Finally compute a result");
                Ok(11)
            })
        })
        .await;

        println!("{log:?}");
        println!("{result:?}");
    }

    // Ignored by default: the infallible counterpart of `tx_test_1`; it calls
    // `infallible_transaction`, which uses the oplog host functions.
    #[test]
    #[ignore]
    async fn tx_test_2() {
        let log = Rc::new(RefCell::new(Vec::new()));

        let op1_execute_log = Rc::clone(&log);
        let op1_rollback_log = Rc::clone(&log);
        let op2_execute_log = Rc::clone(&log);
        let op2_rollback_log = Rc::clone(&log);

        let op1 = sync_operation(
            move |input: String| {
                op1_execute_log
                    .borrow_mut()
                    .push(format!("op1 execute {input}"));
                Ok::<(), ()>(())
            },
            move |input: String, _| {
                op1_rollback_log
                    .borrow_mut()
                    .push(format!("op1 rollback {input}"));
                Ok(())
            },
        );

        let op2 = sync_operation(
            move |_: ()| {
                op2_execute_log.borrow_mut().push("op2 execute".to_string());
                Err::<(), &str>("op2 error")
            },
            move |_: (), r| {
                op2_rollback_log
                    .borrow_mut()
                    .push(format!("op2 rollback {r:?}"));
                Ok(())
            },
        );

        let result = infallible_transaction(|tx| {
            boxed(async move {
                println!("First we execute op1");
                tx.execute(op1, "hello".to_string()).await;
                println!("Then execute op2");
                tx.execute(op2, ()).await;
                println!("Finally compute a result");
                11
            })
        })
        .await;

        println!("{log:?}");
        println!("{result:?}");
    }
}
518
// Tests for operations declared with the `golem_operation` attribute macro;
// only compiled when the `macro` feature is enabled.
#[cfg(test)]
#[cfg(feature = "macro")]
mod macro_tests {
    use crate::{boxed, fallible_transaction, infallible_transaction};
    use golem_rust_macro::golem_operation;
    use test_r::test;

    // NOTE(review): presumably the macro expansion refers to this crate by the
    // path `golem_rust::...`; this local module would make that path resolve
    // from inside the crate itself — confirm against the macro implementation.
    mod golem_rust {
        pub use crate::*;
    }

    // Compensation receives the operation's result followed by all of its inputs.
    #[golem_operation(compensation=test_compensation)]
    fn test_operation(input1: u64, input2: f32) -> Result<bool, String> {
        println!("Op input: {input1}, {input2}");
        Ok(true)
    }

    fn test_compensation(_: bool, input1: u64, input2: f32) -> Result<(), String> {
        println!("Compensation input: {input1}, {input2}");
        Ok(())
    }

    // Compensation receives only the operation's result.
    #[golem_operation(compensation=test_compensation_2)]
    fn test_operation_2(input1: u64, input2: f32) -> Result<bool, String> {
        println!("Op input: {input1}, {input2}");
        Ok(true)
    }

    fn test_compensation_2(result: bool) -> Result<(), String> {
        println!("Compensation for operation result {result:?}");
        Ok(())
    }

    // Compensation takes no arguments at all.
    #[golem_operation(compensation=test_compensation_3)]
    fn test_operation_3(input: String) -> Result<(), String> {
        println!("Op input: {input}");
        Ok(())
    }

    fn test_compensation_3() -> Result<(), String> {
        println!("Compensation for operation, not using any input");
        Ok(())
    }

    // Single-input operation whose compensation receives the result and the input.
    #[golem_operation(compensation=test_compensation_4)]
    fn test_operation_4(input: u64) -> Result<(), String> {
        println!("Op input: {input}");
        Ok(())
    }

    fn test_compensation_4(_: (), input: u64) -> Result<(), String> {
        println!("Compensation for operation with single input {input}");
        Ok(())
    }

    // Ignored by default: calls the annotated functions as methods on a
    // fallible transaction and prints the result.
    // NOTE(review): the `tx.test_operation*` methods are presumably generated
    // by `golem_operation` — confirm against the macro.
    #[test]
    #[ignore]
    async fn tx_test_1() {
        let result = fallible_transaction(|tx| {
            boxed(async move {
                println!("Executing the annotated function as an operation directly");
                tx.test_operation(1, 0.1).await?;
                tx.test_operation_2(1, 0.1).await?;
                tx.test_operation_3("test".to_string()).await?;
                tx.test_operation_4(1).await?;

                Ok(11)
            })
        })
        .await;

        println!("{result:?}");
    }

    // Ignored by default: calls an annotated function as a method on an
    // infallible transaction and prints the result.
    #[test]
    #[ignore]
    async fn tx_test_2() {
        let result = infallible_transaction(|tx| {
            boxed(async move {
                println!("Executing the annotated function as an operation directly");
                let _ = tx.test_operation(1, 0.1).await;
                11
            })
        })
        .await;

        println!("{result:?}");
    }
}