Skip to main content

sqlite3_ext/
transaction.rs

1use super::{types::*, Connection};
2
/// The type of transaction to create.
///
/// Each variant selects the corresponding form of SQLite's `BEGIN` statement: `BEGIN`,
/// `BEGIN IMMEDIATE`, or `BEGIN EXCLUSIVE`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransactionType {
    /// The transaction does not actually start until the database is first accessed. If the first
    /// statement in the transaction is a SELECT, then a read transaction is started. Subsequent
    /// write statements will upgrade the transaction to a write transaction if possible, or return
    /// SQLITE_BUSY. If the first statement in the transaction is a write statement, then a write
    /// transaction is started.
    Deferred,
    /// The database connection starts a new write transaction immediately, without waiting for a
    /// write statement. The transaction method might fail with SQLITE_BUSY if another write
    /// transaction is already active on another database connection.
    Immediate,
    /// Exclusive is similar to Immediate in that a write transaction is started immediately.
    /// Exclusive and Immediate are the same in WAL mode, but in other journaling modes, Exclusive
    /// prevents other database connections from reading the database while the transaction is
    /// underway.
    Exclusive,
}
21
/// Which finalizing statement, if any, a [Transaction] wrapper still has to issue.
#[derive(Debug, PartialEq, Eq)]
enum TransactionState {
    /// Nothing to finalize: the wrapper was never started, or was already committed/rolled back.
    Inactive,
    /// A top-level transaction opened with `BEGIN`, finalized with `COMMIT` or `ROLLBACK`.
    ActiveTransaction,
    /// A savepoint opened with `SAVEPOINT a`, finalized with `RELEASE SAVEPOINT a` or
    /// `ROLLBACK TO a`.
    ActiveSavepoint,
}
28
29/// A RAII wrapper around transactions.
30///
31/// The transaction can be created using [Connection::transaction], and finalized using
32/// [commit](Transaction::commit) or [rollback](Transaction::rollback). If no finalization method
33/// is used, the transaction will automatically be rolled back when it is dropped. For convenience,
34/// Transaction derefs to Connection.
35///
36/// Note that transactions affect the entire database connection. If other threads need to be
37/// prevented from interfering with a transaction, [Connection::lock] should be used.
38///
39/// This interface is optional. It's permitted to execute BEGIN, COMMIT, and ROLLBACK statements
40/// directly, without using this interface. This interface
41#[derive(Debug)]
42pub struct Transaction<'db> {
43    db: &'db Connection,
44    state: TransactionState,
45}
46
47impl Connection {
48    /// Starts a new transaction with the specified behavior.
49    pub fn transaction(&self, tt: TransactionType) -> Result<Transaction<'_>> {
50        let mut txn = Transaction {
51            db: self,
52            state: TransactionState::Inactive,
53        };
54        txn.start(tt)?;
55        Ok(txn)
56    }
57}
58
59impl<'db> Transaction<'db> {
60    fn start(&mut self, tt: TransactionType) -> Result<()> {
61        let sql = match tt {
62            TransactionType::Deferred => "BEGIN",
63            TransactionType::Immediate => "BEGIN IMMEDIATE",
64            TransactionType::Exclusive => "BEGIN EXCLUSIVE",
65        };
66        self.execute(sql, ())?;
67        self.state = TransactionState::ActiveTransaction;
68        Ok(())
69    }
70
71    /// Consumes the Transaction, committing it.
72    pub fn commit(mut self) -> Result<&'db Connection> {
73        self.commit_mut().map(|_| self.db)
74    }
75
76    fn commit_mut(&mut self) -> Result<()> {
77        let ret = match self.state {
78            TransactionState::ActiveTransaction => self.execute("COMMIT", ()),
79            TransactionState::ActiveSavepoint => self.execute("RELEASE SAVEPOINT a", ()),
80            TransactionState::Inactive => panic!("lifetime error"),
81        };
82        self.state = TransactionState::Inactive;
83        ret.map(|_| ())
84    }
85
86    /// Consumes the Transaction, rolling it back.
87    pub fn rollback(mut self) -> Result<&'db Connection> {
88        self.rollback_mut().map(|_| self.db)
89    }
90
91    fn rollback_mut(&mut self) -> Result<()> {
92        let ret = match self.state {
93            TransactionState::ActiveTransaction => self.execute("ROLLBACK", ()),
94            TransactionState::ActiveSavepoint => self.execute("ROLLBACK TO a", ()),
95            TransactionState::Inactive => panic!("lifetime error"),
96        };
97        self.state = TransactionState::Inactive;
98        ret.map(|_| ())
99    }
100
101    /// Create a savepoint for the current transaction. This functions identically to a
102    /// transaction, but committing or rolling back will only affect statements since the savepoint
103    /// was created.
104    pub fn savepoint(&mut self) -> Result<Transaction<'_>> {
105        self.execute("SAVEPOINT a", ())?;
106        let txn = Self {
107            db: self.db,
108            state: TransactionState::ActiveSavepoint,
109        };
110        Ok(txn)
111    }
112}
113
114impl std::ops::Deref for Transaction<'_> {
115    type Target = Connection;
116
117    fn deref(&self) -> &Connection {
118        self.db
119    }
120}
121
122impl Drop for Transaction<'_> {
123    fn drop(&mut self) {
124        if self.state != TransactionState::Inactive {
125            if let Err(e) = self.rollback_mut() {
126                if std::thread::panicking() {
127                    eprintln!("Error while closing SQLite transaction: {e:?}");
128                } else {
129                    panic!("Error while closing SQLite transaction: {e:?}");
130                }
131            }
132        }
133    }
134}
135
#[cfg(all(test, feature = "static"))]
mod test {
    use crate::test_helpers::prelude::*;

    /// Schema shared by every test: one untyped column.
    const CREATE_TABLE: &str = "CREATE TABLE tbl(col)";
    /// Counts the rows that survived the transaction under test.
    const COUNT_ROWS: &str = "SELECT COUNT(*) FROM tbl";

    #[test]
    fn commit() -> Result<()> {
        let helpers = TestHelpers::new();
        helpers.db.execute(CREATE_TABLE, ())?;

        let tx = helpers.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        tx.commit()?;

        let rows = helpers.db.query_row(COUNT_ROWS, (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 1);
        Ok(())
    }

    #[test]
    fn rollback() -> Result<()> {
        let helpers = TestHelpers::new();
        helpers.db.execute(CREATE_TABLE, ())?;

        let tx = helpers.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        tx.rollback()?;

        let rows = helpers.db.query_row(COUNT_ROWS, (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 0);
        Ok(())
    }

    #[test]
    fn drop() -> Result<()> {
        let helpers = TestHelpers::new();
        helpers.db.execute(CREATE_TABLE, ())?;

        {
            let tx = helpers.db.transaction(TransactionType::Deferred)?;
            tx.execute("INSERT INTO tbl VALUES (1)", ())?;
            // `tx` leaves scope here without being finalized.
        }

        let rows = helpers.db.query_row(COUNT_ROWS, (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 0);
        Ok(())
    }

    #[test]
    fn savepoint_commit() -> Result<()> {
        let helpers = TestHelpers::new();
        helpers.db.execute(CREATE_TABLE, ())?;

        let mut tx = helpers.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        let inner = tx.savepoint()?;
        inner.execute("INSERT INTO tbl VALUES (2)", ())?;
        inner.commit()?;
        tx.commit()?;

        let rows = helpers.db.query_row(COUNT_ROWS, (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 2);
        Ok(())
    }

    #[test]
    fn savepoint_rollback() -> Result<()> {
        let helpers = TestHelpers::new();
        helpers.db.execute(CREATE_TABLE, ())?;

        let mut tx = helpers.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        let inner = tx.savepoint()?;
        inner.execute("INSERT INTO tbl VALUES (2)", ())?;
        inner.rollback()?;
        tx.commit()?;

        let rows = helpers.db.query_row(COUNT_ROWS, (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 1);
        Ok(())
    }

    #[test]
    fn savepoint_drop() -> Result<()> {
        let helpers = TestHelpers::new();
        helpers.db.execute(CREATE_TABLE, ())?;

        let mut tx = helpers.db.transaction(TransactionType::Deferred)?;
        tx.execute("INSERT INTO tbl VALUES (1)", ())?;
        {
            let inner = tx.savepoint()?;
            inner.execute("INSERT INTO tbl VALUES (2)", ())?;
            // `inner` leaves scope here without being finalized.
        }
        tx.commit()?;

        let rows = helpers.db.query_row(COUNT_ROWS, (), |row| Ok(row[0].get_i64()))?;
        assert_eq!(rows, 1);
        Ok(())
    }

    #[test]
    fn commit_fail() -> Result<()> {
        let helpers = TestHelpers::new();
        helpers.db.execute(CREATE_TABLE, ())?;

        let tx = helpers.db.transaction(TransactionType::Deferred)?;
        // End the transaction behind the wrapper's back so COMMIT has nothing to commit.
        tx.execute("ROLLBACK", ())?;
        let err = match tx.commit() {
            Ok(_) => unreachable!(),
            Err(e) => e,
        };
        assert_eq!(err.to_string(), "cannot commit - no transaction is active");
        Ok(())
    }
}