async_stm/
auxtx.rs

/// An auxiliary transaction that gets co-committed with the STM transaction,
/// unless the STM transaction is aborted or has to be retried, in which case
/// the auxiliary transaction is rolled back.
///
/// The auxiliary transaction can also signal that it is unable to be committed,
/// in which case the whole atomic transaction will be retried.
///
/// The database is not expected to return an error here, because of how
/// `Transaction::commit` works. If there is a failure that needs to be surfaced,
/// at the moment the database would have to buffer that error and return it on a
/// subsequent operation, mapped to an `StmError::Abort`.
///
/// Both methods take `self` by value, so a given auxiliary transaction can be
/// resolved (committed or rolled back) at most once.
pub trait Aux {
    /// Commit the auxiliary transaction if the STM transaction did not detect any errors.
    /// The STM transaction is checked first, because committing the database involves IO
    /// and is expected to be slower.
    ///
    /// Returns `false` if there are write conflicts in the persistent database itself,
    /// to cause a complete retry for the whole transaction.
    fn commit(self) -> bool;

    /// Roll back the auxiliary transaction if the STM transaction was aborted,
    /// or if it is going to be retried.
    fn rollback(self);
}
22
23/// Empty implementation for when we are not using any auxiliary transaction.
24pub(crate) struct NoAux;
25
26impl Aux for NoAux {
27    fn commit(self) -> bool {
28        true
29    }
30    fn rollback(self) {}
31}
32
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, Mutex},
        thread,
    };

    use crate::auxtx::*;
    use crate::{
        abort, atomically, atomically_aux, atomically_or_err_aux, retry, test::TestError1, TVar,
    };

    /// A toy "persistent database" holding a single counter; clones share the same counter.
    #[derive(Clone)]
    struct TestAuxDb {
        counter: Arc<Mutex<i32>>,
    }

    /// A transaction against [`TestAuxDb`]. It works on a private copy of the counter
    /// taken in [`TestAuxDb::begin`] and only writes it back in [`Aux::commit`].
    struct TestAuxTx<'a> {
        db: &'a TestAuxDb,
        counter: i32,
        /// Set by `commit` / `rollback`; checked in `Drop` to catch unresolved transactions.
        finished: bool,
    }

    impl TestAuxDb {
        fn new() -> TestAuxDb {
            TestAuxDb {
                counter: Arc::new(Mutex::new(0)),
            }
        }

        /// Start a transaction working on a snapshot of the currently committed counter.
        fn begin(&self) -> TestAuxTx<'_> {
            TestAuxTx {
                db: self,
                counter: self.counter(),
                finished: false,
            }
        }

        /// Read the committed counter value.
        fn counter(&self) -> i32 {
            *self.counter.lock().unwrap()
        }
    }

    impl Aux for TestAuxTx<'_> {
        /// Write the snapshot back unconditionally; this test database never reports a conflict.
        fn commit(mut self) -> bool {
            *self.db.counter.lock().unwrap() = self.counter;
            self.finished = true;
            true
        }

        /// Discard the snapshot.
        fn rollback(mut self) {
            self.finished = true;
        }
    }

    impl Drop for TestAuxTx<'_> {
        fn drop(&mut self) {
            // Catch transactions dropped without being resolved, but don't panic while
            // already unwinding from a failed assertion (a double panic aborts the process).
            if !self.finished && !thread::panicking() {
                panic!("Transaction prematurely dropped. Must call `.commit()` or `.rollback()`.");
            }
        }
    }

    /// The aux transaction is rolled back on abort, committed on success, and rolled back
    /// and then restarted from scratch when the STM transaction retries.
    #[tokio::test]
    async fn aux_commit_rollback() {
        let db = TestAuxDb::new();

        // Aborting the STM transaction must roll back the aux transaction.
        atomically_or_err_aux::<_, TestError1, _, _, _>(
            || db.begin(),
            |atx| {
                atx.counter = 1;
                abort(TestError1)?;
                Ok(())
            },
        )
        .await
        .expect_err("Should be aborted");

        assert_eq!(db.counter(), 0);

        // A successful STM transaction commits the aux transaction.
        atomically_aux(
            || db.begin(),
            |atx| {
                atx.counter = 1;
                Ok(())
            },
        )
        .await;

        assert_eq!(db.counter(), 1);

        // A retry must roll back the aux transaction; the re-run begins a fresh one.
        let ta = TVar::new(42);
        let dbc = db.clone();
        let tac = ta.clone();
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
        let handle = tokio::spawn(async move {
            atomically_aux(
                || dbc.begin(),
                |atx| {
                    let a = tac.read()?;
                    atx.counter += *a;
                    if *a == 42 {
                        // Signal that we are entering the retry.
                        sender.send(()).unwrap();
                        retry()?;
                    }
                    Ok(())
                },
            )
            .await;
        });

        // `None` here would mean the task ended without signalling. Any panic in the
        // task is surfaced by `handle.await.unwrap()` below, so the result is ignored.
        let _ = receiver.recv().await;
        // The retrying transaction's aux changes must not have been committed.
        assert_eq!(db.counter(), 1);

        // Writing a value to `ta` will trigger the retry.
        atomically(|| ta.write(10)).await;
        handle.await.unwrap();

        // The re-run started from the committed 1 and added the new value 10.
        assert_eq!(db.counter(), 11);
    }
}