ic_sqlite_vfs/db/
transaction.rs1use crate::db::connection::Connection;
7use crate::db::DbError;
8use crate::sqlite_vfs::stable_blob;
9use std::ops::Deref;
10
11pub struct UpdateConnection<'connection> {
12 connection: &'connection Connection,
13 savepoint_id: u64,
14}
15
16impl<'connection> UpdateConnection<'connection> {
17 fn new(connection: &'connection Connection) -> Self {
18 Self {
19 connection,
20 savepoint_id: 0,
21 }
22 }
23
24 pub fn savepoint<T, F>(&mut self, f: F) -> Result<T, DbError>
25 where
26 F: FnOnce(&mut UpdateConnection<'connection>) -> Result<T, DbError>,
27 {
28 let name = self.next_savepoint_name();
29 self.connection
30 .execute_batch(&format!("SAVEPOINT {name}"))?;
31 match f(self) {
32 Ok(value) => {
33 self.connection
34 .execute_batch(&format!("RELEASE SAVEPOINT {name}"))?;
35 Ok(value)
36 }
37 Err(error) => {
38 let _ = self
39 .connection
40 .execute_batch(&format!("ROLLBACK TRANSACTION TO SAVEPOINT {name}"));
41 let _ = self
42 .connection
43 .execute_batch(&format!("RELEASE SAVEPOINT {name}"));
44 Err(error)
45 }
46 }
47 }
48
49 fn next_savepoint_name(&mut self) -> String {
50 let id = self.savepoint_id;
51 self.savepoint_id += 1;
52 format!("__ic_sqlite_sp_{id}")
53 }
54}
55
56impl Deref for UpdateConnection<'_> {
57 type Target = Connection;
58
59 fn deref(&self) -> &Self::Target {
60 self.connection
61 }
62}
63
64pub fn run_immediate<T, F>(connection: &Connection, f: F) -> Result<T, DbError>
65where
66 F: FnOnce(&mut UpdateConnection<'_>) -> Result<T, DbError>,
67{
68 connection.execute_batch_nul_terminated(BEGIN_SQL)?;
69 let mut update_connection = UpdateConnection::new(connection);
70 match f(&mut update_connection) {
71 Ok(value) => {
72 if let Err(error) = connection.execute_batch_nul_terminated(COMMIT_SQL) {
73 let _ = connection.execute_batch_nul_terminated(ROLLBACK_SQL);
74 return Err(error);
75 }
76 stable_blob::commit_update()?;
77 Ok(value)
78 }
79 Err(error) => {
80 let _ = connection.execute_batch_nul_terminated(ROLLBACK_SQL);
81 stable_blob::rollback_update();
82 Err(error)
83 }
84 }
85}
86
// Transaction-control statements for `run_immediate`. Each literal carries its
// own trailing NUL because it is passed to `execute_batch_nul_terminated`.

/// Opens the transaction (plain `BEGIN`, i.e. SQLite's DEFERRED mode).
const BEGIN_SQL: &[u8] = "BEGIN\0".as_bytes();
/// Commits the transaction after the user closure succeeds.
const COMMIT_SQL: &[u8] = "COMMIT\0".as_bytes();
/// Abandons the transaction on the failure paths.
const ROLLBACK_SQL: &[u8] = "ROLLBACK\0".as_bytes();
90
#[cfg(test)]
mod tests {
    use super::{BEGIN_SQL, COMMIT_SQL, ROLLBACK_SQL};
    use std::ffi::CStr;

    /// Each transaction statement must end in a NUL byte and form a valid C
    /// string (`CStr::from_bytes_with_nul` also rejects interior NULs).
    #[test]
    fn transaction_sql_is_nul_terminated() {
        let statements = [BEGIN_SQL, COMMIT_SQL, ROLLBACK_SQL];

        for sql in statements {
            assert_eq!(sql.last(), Some(&0));
        }
        for sql in statements {
            assert!(CStr::from_bytes_with_nul(sql).is_ok());
        }
    }
}