use crate::db::connection::Connection;
use crate::db::DbError;
use crate::sqlite_vfs::stable_blob;
use std::ops::Deref;
/// Handle given to closures that run inside a transaction opened by
/// `run_immediate`.
///
/// It borrows the underlying [`Connection`] (and dereferences to it) and
/// additionally offers nested savepoints through `savepoint`.
pub struct UpdateConnection<'connection> {
// Borrowed connection on which all SQL is executed.
connection: &'connection Connection,
// Counter used to derive the name of the next savepoint; starts at 0 and is
// incremented once per generated name.
savepoint_id: u64,
}
impl<'connection> UpdateConnection<'connection> {
    /// Wraps `connection`, starting the savepoint counter at zero.
    fn new(connection: &'connection Connection) -> Self {
        Self {
            savepoint_id: 0,
            connection,
        }
    }

    /// Runs `f` inside a freshly named SQL savepoint.
    ///
    /// The savepoint is released when `f` returns `Ok`. When `f` returns
    /// `Err`, the savepoint is rolled back and then released, and the error
    /// from `f` is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns the error from opening the savepoint, from `f`, or from
    /// releasing the savepoint after `f` succeeded. Failures of the cleanup
    /// statements issued after `f` failed are ignored so that the error from
    /// `f` is the one reported.
    pub fn savepoint<T, F>(&mut self, f: F) -> Result<T, DbError>
    where
        F: FnOnce(&mut UpdateConnection<'connection>) -> Result<T, DbError>,
    {
        let label = self.next_savepoint_name();
        let open_sql = format!("SAVEPOINT {label}");
        let release_sql = format!("RELEASE SAVEPOINT {label}");
        self.connection.execute_batch(&open_sql)?;

        let outcome = f(self);
        if outcome.is_err() {
            // Best-effort unwind: undo the savepoint's changes, then pop it.
            let undo_sql = format!("ROLLBACK TRANSACTION TO SAVEPOINT {label}");
            let _ = self.connection.execute_batch(&undo_sql);
            let _ = self.connection.execute_batch(&release_sql);
            return outcome;
        }

        self.connection.execute_batch(&release_sql)?;
        outcome
    }

    /// Produces the next savepoint name and advances the counter, so every
    /// call on the same `UpdateConnection` yields a distinct name.
    fn next_savepoint_name(&mut self) -> String {
        let name = format!("__ic_sqlite_sp_{}", self.savepoint_id);
        self.savepoint_id += 1;
        name
    }
}
/// Lets an `UpdateConnection` be used wherever a `&Connection` method is
/// needed by delegating to the wrapped connection.
impl<'connection> Deref for UpdateConnection<'connection> {
    type Target = Connection;

    fn deref(&self) -> &Connection {
        self.connection
    }
}
/// Runs `f` inside a SQL transaction on `connection`.
///
/// The transaction is opened with `BEGIN_SQL`. If `f` returns `Ok`, the
/// transaction is committed with `COMMIT_SQL` and then
/// `stable_blob::commit_update()` is called. On any failure after the
/// transaction was opened (either `f` returning `Err` or the commit statement
/// failing), the SQL transaction is rolled back with `ROLLBACK_SQL` and
/// `stable_blob::rollback_update()` is called, so the two failure paths leave
/// the blob-side update in the same state.
///
/// # Errors
///
/// Returns the error from opening the transaction, from `f`, from the commit
/// statement, or from `stable_blob::commit_update()`. Errors from the
/// rollback statement itself are ignored so the original failure is reported.
pub fn run_immediate<T, F>(connection: &Connection, f: F) -> Result<T, DbError>
where
    F: FnOnce(&mut UpdateConnection<'_>) -> Result<T, DbError>,
{
    connection.execute_batch_nul_terminated(BEGIN_SQL)?;
    let mut update_connection = UpdateConnection::new(connection);
    match f(&mut update_connection) {
        Ok(value) => {
            if let Err(error) = connection.execute_batch_nul_terminated(COMMIT_SQL) {
                // A failed COMMIT must be unwound exactly like a failed
                // closure: roll back the SQL transaction AND the stable-blob
                // update, otherwise the blob update is left neither committed
                // nor rolled back.
                let _ = connection.execute_batch_nul_terminated(ROLLBACK_SQL);
                stable_blob::rollback_update();
                return Err(error);
            }
            // Only finalize the blob update once SQLite has committed.
            stable_blob::commit_update()?;
            Ok(value)
        }
        Err(error) => {
            // Best-effort rollback; the closure's error is what gets reported.
            let _ = connection.execute_batch_nul_terminated(ROLLBACK_SQL);
            stable_blob::rollback_update();
            Err(error)
        }
    }
}
// Transaction control statements, stored with a trailing NUL byte because they
// are passed to `execute_batch_nul_terminated`.
//
// The update entry point is named `run_immediate`, so the transaction is opened
// with `BEGIN IMMEDIATE`. A bare `BEGIN` is DEFERRED in SQLite and would only
// start the write transaction at the first write statement.
const BEGIN_SQL: &[u8] = b"BEGIN IMMEDIATE\0";
const COMMIT_SQL: &[u8] = b"COMMIT\0";
const ROLLBACK_SQL: &[u8] = b"ROLLBACK\0";
#[cfg(test)]
mod tests {
    use super::{BEGIN_SQL, COMMIT_SQL, ROLLBACK_SQL};
    use std::ffi::CStr;

    /// Every transaction statement must end in exactly one NUL byte (and
    /// contain no interior NUL) so it can be handed over as a C string.
    #[test]
    fn transaction_sql_is_nul_terminated() {
        for sql in [BEGIN_SQL, COMMIT_SQL, ROLLBACK_SQL] {
            assert_eq!(sql.last(), Some(&0));
            assert!(CStr::from_bytes_with_nul(sql).is_ok());
        }
    }
}