Skip to main content

ic_sqlite_vfs/db/
transaction.rs

//! Synchronous transaction wrapper for update canister methods.
//!
//! The transaction closure is a plain (non-`async`) function, so SQLite state
//! cannot be held across an inter-canister call or any other `await` point.
5
6use crate::db::connection::Connection;
7use crate::db::DbError;
8use crate::sqlite_vfs::stable_blob;
9use std::ops::Deref;
10
11pub struct UpdateConnection<'connection> {
12    connection: &'connection Connection,
13    savepoint_id: u64,
14}
15
16impl<'connection> UpdateConnection<'connection> {
17    fn new(connection: &'connection Connection) -> Self {
18        Self {
19            connection,
20            savepoint_id: 0,
21        }
22    }
23
24    pub fn savepoint<T, F>(&mut self, f: F) -> Result<T, DbError>
25    where
26        F: FnOnce(&mut UpdateConnection<'connection>) -> Result<T, DbError>,
27    {
28        let name = self.next_savepoint_name();
29        self.connection
30            .execute_batch(&format!("SAVEPOINT {name}"))?;
31        match f(self) {
32            Ok(value) => {
33                self.connection
34                    .execute_batch(&format!("RELEASE SAVEPOINT {name}"))?;
35                Ok(value)
36            }
37            Err(error) => {
38                let _ = self
39                    .connection
40                    .execute_batch(&format!("ROLLBACK TRANSACTION TO SAVEPOINT {name}"));
41                let _ = self
42                    .connection
43                    .execute_batch(&format!("RELEASE SAVEPOINT {name}"));
44                Err(error)
45            }
46        }
47    }
48
49    fn next_savepoint_name(&mut self) -> String {
50        let id = self.savepoint_id;
51        self.savepoint_id += 1;
52        format!("__ic_sqlite_sp_{id}")
53    }
54}
55
56impl Deref for UpdateConnection<'_> {
57    type Target = Connection;
58
59    fn deref(&self) -> &Self::Target {
60        self.connection
61    }
62}
63
64pub fn run_immediate<T, F>(connection: &Connection, f: F) -> Result<T, DbError>
65where
66    F: FnOnce(&mut UpdateConnection<'_>) -> Result<T, DbError>,
67{
68    connection.execute_batch_nul_terminated(BEGIN_SQL)?;
69    let mut update_connection = UpdateConnection::new(connection);
70    match f(&mut update_connection) {
71        Ok(value) => {
72            if let Err(error) = connection.execute_batch_nul_terminated(COMMIT_SQL) {
73                let _ = connection.execute_batch_nul_terminated(ROLLBACK_SQL);
74                return Err(error);
75            }
76            stable_blob::commit_update()?;
77            Ok(value)
78        }
79        Err(error) => {
80            let _ = connection.execute_batch_nul_terminated(ROLLBACK_SQL);
81            stable_blob::rollback_update();
82            Err(error)
83        }
84    }
85}
86
/// Opens a transaction. NUL-terminated for `execute_batch_nul_terminated`.
const BEGIN_SQL: &[u8] = "BEGIN\0".as_bytes();
/// Commits the open transaction. NUL-terminated for `execute_batch_nul_terminated`.
const COMMIT_SQL: &[u8] = "COMMIT\0".as_bytes();
/// Rolls back the open transaction. NUL-terminated for `execute_batch_nul_terminated`.
const ROLLBACK_SQL: &[u8] = "ROLLBACK\0".as_bytes();
90
#[cfg(test)]
mod tests {
    use super::{BEGIN_SQL, COMMIT_SQL, ROLLBACK_SQL};
    use std::ffi::CStr;

    /// Each transaction-control constant must end in a NUL byte and form a
    /// valid C string. `CStr::from_bytes_with_nul` rejects both a missing
    /// terminator and any interior NUL.
    #[test]
    fn transaction_sql_is_nul_terminated() {
        let statements = [BEGIN_SQL, COMMIT_SQL, ROLLBACK_SQL];
        for sql in statements {
            assert_eq!(sql.last(), Some(&0));
        }
        for sql in statements {
            assert!(CStr::from_bytes_with_nul(sql).is_ok());
        }
    }
}