sqlite3_ext/
transaction.rs1use super::{types::*, Connection};
2
/// The locking behavior requested when a transaction is started.
///
/// Each variant selects the `BEGIN` statement that is issued when the transaction is opened
/// by [`Connection::transaction`]. See the SQLite documentation on `BEGIN TRANSACTION` for
/// the exact locking semantics of each mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionType {
    /// Issues a plain `BEGIN`.
    Deferred,
    /// Issues `BEGIN IMMEDIATE`.
    Immediate,
    /// Issues `BEGIN EXCLUSIVE`.
    Exclusive,
}
21
/// Tracks what, if anything, a [`Transaction`] guard still has to close.
#[derive(Debug, Eq, PartialEq)]
enum TransactionState {
    /// The guard owns an open top-level transaction.
    ActiveTransaction,
    /// The guard owns an open savepoint.
    ActiveSavepoint,
    /// Nothing is open: either not started yet, or already committed / rolled back.
    Inactive,
}
28
29#[derive(Debug)]
42pub struct Transaction<'db> {
43 db: &'db Connection,
44 state: TransactionState,
45}
46
47impl Connection {
48 pub fn transaction(&self, tt: TransactionType) -> Result<Transaction<'_>> {
50 let mut txn = Transaction {
51 db: self,
52 state: TransactionState::Inactive,
53 };
54 txn.start(tt)?;
55 Ok(txn)
56 }
57}
58
59impl<'db> Transaction<'db> {
60 fn start(&mut self, tt: TransactionType) -> Result<()> {
61 let sql = match tt {
62 TransactionType::Deferred => "BEGIN",
63 TransactionType::Immediate => "BEGIN IMMEDIATE",
64 TransactionType::Exclusive => "BEGIN EXCLUSIVE",
65 };
66 self.execute(sql, ())?;
67 self.state = TransactionState::ActiveTransaction;
68 Ok(())
69 }
70
71 pub fn commit(mut self) -> Result<&'db Connection> {
73 self.commit_mut().map(|_| self.db)
74 }
75
76 fn commit_mut(&mut self) -> Result<()> {
77 let ret = match self.state {
78 TransactionState::ActiveTransaction => self.execute("COMMIT", ()),
79 TransactionState::ActiveSavepoint => self.execute("RELEASE SAVEPOINT a", ()),
80 TransactionState::Inactive => panic!("lifetime error"),
81 };
82 self.state = TransactionState::Inactive;
83 ret.map(|_| ())
84 }
85
86 pub fn rollback(mut self) -> Result<&'db Connection> {
88 self.rollback_mut().map(|_| self.db)
89 }
90
91 fn rollback_mut(&mut self) -> Result<()> {
92 let ret = match self.state {
93 TransactionState::ActiveTransaction => self.execute("ROLLBACK", ()),
94 TransactionState::ActiveSavepoint => self.execute("ROLLBACK TO a", ()),
95 TransactionState::Inactive => panic!("lifetime error"),
96 };
97 self.state = TransactionState::Inactive;
98 ret.map(|_| ())
99 }
100
101 pub fn savepoint(&mut self) -> Result<Transaction<'_>> {
105 self.execute("SAVEPOINT a", ())?;
106 let txn = Self {
107 db: self.db,
108 state: TransactionState::ActiveSavepoint,
109 };
110 Ok(txn)
111 }
112}
113
114impl std::ops::Deref for Transaction<'_> {
115 type Target = Connection;
116
117 fn deref(&self) -> &Connection {
118 self.db
119 }
120}
121
122impl Drop for Transaction<'_> {
123 fn drop(&mut self) {
124 if self.state != TransactionState::Inactive {
125 if let Err(e) = self.rollback_mut() {
126 if std::thread::panicking() {
127 eprintln!("Error while closing SQLite transaction: {e:?}");
128 } else {
129 panic!("Error while closing SQLite transaction: {e:?}");
130 }
131 }
132 }
133 }
134}
135
#[cfg(all(test, feature = "static"))]
mod test {
    use crate::test_helpers::prelude::*;

    /// A committed transaction keeps its insert.
    #[test]
    fn commit() -> Result<()> {
        let helper = TestHelpers::new();
        helper.db.execute("CREATE TABLE tbl(col)", ())?;
        let tx = helper.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        tx.commit()?;
        let rows = helper
            .db
            .query_row("SELECT COUNT(*) FROM tbl", (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 1);
        Ok(())
    }

    /// An explicitly rolled-back transaction discards its insert.
    #[test]
    fn rollback() -> Result<()> {
        let helper = TestHelpers::new();
        helper.db.execute("CREATE TABLE tbl(col)", ())?;
        let tx = helper.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        tx.rollback()?;
        let rows = helper
            .db
            .query_row("SELECT COUNT(*) FROM tbl", (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 0);
        Ok(())
    }

    /// Dropping an active transaction rolls it back.
    #[test]
    fn drop() -> Result<()> {
        let helper = TestHelpers::new();
        helper.db.execute("CREATE TABLE tbl(col)", ())?;
        {
            let tx = helper.db.transaction(TransactionType::Deferred)?;
            tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        }
        let rows = helper
            .db
            .query_row("SELECT COUNT(*) FROM tbl", (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 0);
        Ok(())
    }

    /// Committing a savepoint and then its parent keeps both inserts.
    #[test]
    fn savepoint_commit() -> Result<()> {
        let helper = TestHelpers::new();
        helper.db.execute("CREATE TABLE tbl(col)", ())?;
        let mut tx = helper.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        let nested = tx.savepoint()?;
        nested.execute("INSERT INTO tbl VALUES (2)", ())?;
        nested.commit()?;
        tx.commit()?;
        let rows = helper
            .db
            .query_row("SELECT COUNT(*) FROM tbl", (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 2);
        Ok(())
    }

    /// Rolling back a savepoint discards only the savepoint's insert.
    #[test]
    fn savepoint_rollback() -> Result<()> {
        let helper = TestHelpers::new();
        helper.db.execute("CREATE TABLE tbl(col)", ())?;
        let mut tx = helper.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        let nested = tx.savepoint()?;
        nested.execute("INSERT INTO tbl VALUES (2)", ())?;
        nested.rollback()?;
        tx.commit()?;
        let rows = helper
            .db
            .query_row("SELECT COUNT(*) FROM tbl", (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 1);
        Ok(())
    }

    /// Dropping an active savepoint rolls back only the savepoint's insert.
    #[test]
    fn savepoint_drop() -> Result<()> {
        let helper = TestHelpers::new();
        helper.db.execute("CREATE TABLE tbl(col)", ())?;
        let mut tx = helper.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        {
            let nested = tx.savepoint()?;
            nested.execute("INSERT INTO tbl VALUES (2)", ())?;
        }
        tx.commit()?;
        let rows = helper
            .db
            .query_row("SELECT COUNT(*) FROM tbl", (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 1);
        Ok(())
    }

    /// Committing after the transaction was ended behind the guard's back reports the
    /// SQLite error instead of succeeding.
    #[test]
    fn commit_fail() -> Result<()> {
        let helper = TestHelpers::new();
        helper.db.execute("CREATE TABLE tbl(col)", ())?;
        let tx = helper.db.transaction(TransactionType::Deferred)?;
        tx.execute("ROLLBACK", ())?;
        if let Err(e) = tx.commit() {
            assert_eq!(e.to_string(), "cannot commit - no transaction is active");
        } else {
            unreachable!();
        }
        Ok(())
    }
}