/// An auxiliary (non-STM) transaction whose outcome is tied to an STM transaction.
///
/// Both methods take `self` by value, so an implementation can be finalized at most
/// once, through at most one of [`commit`](Aux::commit) or [`rollback`](Aux::rollback).
pub trait Aux {
    /// Commits the auxiliary transaction, consuming it.
    ///
    /// Returns a flag describing the outcome of the commit.
    /// NOTE(review): presumably `false` signals that the auxiliary commit failed;
    /// confirm against how the `*_aux` runners consume this value.
    fn commit(self) -> bool;

    /// Rolls back the auxiliary transaction, consuming it.
    ///
    /// NOTE(review): presumably invoked when the STM transaction aborts or retries;
    /// confirm against the `*_aux` runners.
    fn rollback(self);
}
22
23pub(crate) struct NoAux;
25
26impl Aux for NoAux {
27 fn commit(self) -> bool {
28 true
29 }
30 fn rollback(self) {}
31}
32
#[cfg(test)]
mod test {
    use std::{
        sync::{Arc, Mutex},
        thread,
    };

    use crate::auxtx::*;
    use crate::{
        abort, atomically, atomically_aux, atomically_or_err_aux, retry, test::TestError1, TVar,
    };

    /// In-memory stand-in for an external store holding a single shared counter.
    ///
    /// Clones share the same `Arc<Mutex<_>>`, so a clone moved into a spawned task
    /// reads and writes the same counter as the original.
    #[derive(Clone)]
    struct TestAuxDb {
        counter: Arc<Mutex<i32>>,
    }

    /// Auxiliary transaction over a [`TestAuxDb`].
    ///
    /// It snapshots the counter when created, and the STM closure mutates the private
    /// copy. The copy is written back to the shared counter only by [`Aux::commit`].
    struct TestAuxTx<'a> {
        db: &'a TestAuxDb,
        /// Working copy of the counter; written back to `db` only on commit.
        counter: i32,
        /// Set once `commit` or `rollback` has run; checked by the `Drop` impl.
        finished: bool,
    }

    impl TestAuxDb {
        fn new() -> TestAuxDb {
            TestAuxDb {
                counter: Arc::new(Mutex::new(0)),
            }
        }

        /// Starts an auxiliary transaction holding a snapshot of the current counter.
        fn begin(&self) -> TestAuxTx<'_> {
            let snapshot = *self.counter.lock().unwrap();
            TestAuxTx {
                db: self,
                counter: snapshot,
                finished: false,
            }
        }

        /// Returns the counter value currently stored in the shared cell.
        fn counter(&self) -> i32 {
            *self.counter.lock().unwrap()
        }
    }

    impl Aux for TestAuxTx<'_> {
        /// Publishes the working copy to the shared counter and always reports `true`.
        fn commit(mut self) -> bool {
            let mut guard = self.db.counter.lock().unwrap();
            *guard = self.counter;
            self.finished = true;
            true
        }

        /// Discards the working copy; the shared counter is left untouched.
        fn rollback(mut self) {
            self.finished = true;
        }
    }

    /// Fails the test if a transaction is dropped without `commit` or `rollback`.
    ///
    /// The check is skipped while the thread is already panicking, so an earlier
    /// failure is not turned into a panic during unwinding.
    impl Drop for TestAuxTx<'_> {
        fn drop(&mut self) {
            if !self.finished && !thread::panicking() {
                panic!("Transaction prematurely dropped. Must call `.commit()` or `.rollback()`.");
            }
        }
    }

    /// Checks that the auxiliary transaction follows the STM outcome.
    ///
    /// Its write must not be visible after `abort` and must be visible after a
    /// successful run. After `retry` it must not be visible until the `TVar` it read
    /// is changed and the transaction runs again.
    ///
    /// Because `TestAuxTx` panics when dropped unfinished, a runner that drops an
    /// auxiliary transaction without finalizing it is also expected to fail this test.
    #[tokio::test]
    async fn aux_commit_rollback() {
        let db = TestAuxDb::new();

        // Aborted STM transaction: the auxiliary write must not become visible.
        atomically_or_err_aux::<_, TestError1, _, _, _>(
            || db.begin(),
            |atx| {
                atx.counter = 1;
                abort(TestError1)?;
                Ok(())
            },
        )
        .await
        .expect_err("Should be aborted");

        assert_eq!(db.counter(), 0);

        // Successful STM transaction: the auxiliary write is committed.
        atomically_aux(
            || db.begin(),
            |atx| {
                atx.counter = 1;
                Ok(())
            },
        )
        .await;

        assert_eq!(db.counter(), 1);

        // Retrying transaction: while `ta == 42` it signals the test and retries,
        // so nothing may be committed until `ta` is changed.
        let ta = TVar::new(42);
        let dbc = db.clone();
        let tac = ta.clone();
        let (sender, mut receiver) = tokio::sync::mpsc::unbounded_channel();
        let handle = tokio::spawn(async move {
            atomically_aux(
                || dbc.begin(),
                |atx| {
                    let a = tac.read()?;
                    atx.counter += *a;
                    if *a == 42 {
                        sender.send(()).unwrap();
                        retry()?;
                    }
                    Ok(())
                },
            )
            .await;
        });

        // Wait until the spawned transaction has run at least once. Ignoring a `None`
        // here would let the next assertion pass without checking anything if the
        // task never signalled.
        receiver
            .recv()
            .await
            .expect("spawned transaction should signal before retrying");
        assert_eq!(db.counter(), 1);

        atomically(|| ta.write(10)).await;
        handle.await.unwrap();

        // The re-run read `ta == 10` and added it to the previously committed 1.
        assert_eq!(db.counter(), 11);
    }
}