use fsqlite_error::FrankenError;
use fsqlite_types::value::SqliteValue;
use crate::{Connection, Row};
use super::params::ParamValue;
/// A transaction guard borrowing a [`Connection`].
///
/// Obtained from [`TransactionExt::transaction`], which begins the transaction.
/// Finish it with [`Transaction::commit`] or [`Transaction::rollback`]. If the
/// value is dropped before either has succeeded, its `Drop` impl issues a
/// best-effort rollback and discards any error from it.
#[must_use = "dropping a Transaction without calling `commit` rolls it back immediately"]
pub struct Transaction<'a> {
    /// Connection the transaction was begun on; every statement is forwarded to it.
    conn: &'a Connection,
    /// Set once `commit` or `rollback` has succeeded, so `Drop` skips its rollback.
    finalized: bool,
}
impl<'a> Transaction<'a> {
    /// Begins a transaction on `conn` and wraps it in an unfinalized guard.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::begin_transaction`. No guard is
    /// created in that case, so nothing is rolled back on drop.
    fn new(conn: &'a Connection) -> Result<Self, FrankenError> {
        conn.begin_transaction()?;
        Ok(Self {
            conn,
            finalized: false,
        })
    }

    /// Copies compat-layer [`ParamValue`]s into the owned [`SqliteValue`]s that the
    /// connection's `*_with_params` methods accept.
    ///
    /// Shared by every `ParamValue`-taking method in this impl so the conversion
    /// lives in exactly one place.
    fn owned_param_values(params: &[ParamValue]) -> Vec<SqliteValue> {
        params.iter().map(|p| p.0.clone()).collect()
    }

    /// Commits the transaction.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::commit_transaction`. On failure the
    /// guard stays unfinalized, so dropping it still attempts a rollback.
    pub fn commit(&mut self) -> Result<(), FrankenError> {
        // NOTE(review): `finalized` is not checked here; a repeated call is
        // forwarded to the connection again — confirm that is intended.
        self.conn.commit_transaction()?;
        // Mark finalized only after success so `Drop` can still clean up on error.
        self.finalized = true;
        Ok(())
    }

    /// Rolls back the transaction.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::rollback_transaction`. On failure the
    /// guard stays unfinalized, so dropping it attempts the rollback once more.
    pub fn rollback(&mut self) -> Result<(), FrankenError> {
        self.conn.rollback_transaction()?;
        self.finalized = true;
        Ok(())
    }

    /// Executes `sql` on the connection owning this transaction.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::execute`.
    pub fn execute(&self, sql: &str) -> Result<usize, FrankenError> {
        self.conn.execute(sql)
    }

    /// Executes `sql` with positional `params` on the owning connection.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::execute_with_params`.
    pub fn execute_with_params(
        &self,
        sql: &str,
        params: &[SqliteValue],
    ) -> Result<usize, FrankenError> {
        self.conn.execute_with_params(sql, params)
    }

    /// Like [`Transaction::execute_with_params`], but takes compat-layer
    /// [`ParamValue`]s, which are cloned into owned `SqliteValue`s first.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::execute_with_params`.
    pub fn execute_compat(&self, sql: &str, params: &[ParamValue]) -> Result<usize, FrankenError> {
        let values = Self::owned_param_values(params);
        self.conn.execute_with_params(sql, &values)
    }

    /// Runs `sql` and returns all result rows.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::query`.
    pub fn query(&self, sql: &str) -> Result<Vec<Row>, FrankenError> {
        self.conn.query(sql)
    }

    /// Runs `sql` with positional `params` and returns all result rows.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::query_with_params`.
    pub fn query_with_params(
        &self,
        sql: &str,
        params: &[SqliteValue],
    ) -> Result<Vec<Row>, FrankenError> {
        self.conn.query_with_params(sql, params)
    }

    /// Like [`Transaction::query_with_params`], but takes compat-layer
    /// [`ParamValue`]s.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::query_with_params`.
    pub fn query_params(&self, sql: &str, params: &[ParamValue]) -> Result<Vec<Row>, FrankenError> {
        let values = Self::owned_param_values(params);
        self.conn.query_with_params(sql, &values)
    }

    /// Runs `sql` and returns a single row.
    ///
    /// Behavior when the query yields zero or several rows is whatever
    /// `Connection::query_row` defines.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::query_row`.
    pub fn query_row(&self, sql: &str) -> Result<Row, FrankenError> {
        self.conn.query_row(sql)
    }

    /// Runs `sql` with positional `params` and returns a single row.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::query_row_with_params`.
    pub fn query_row_with_params(
        &self,
        sql: &str,
        params: &[SqliteValue],
    ) -> Result<Row, FrankenError> {
        self.conn.query_row_with_params(sql, params)
    }

    /// Fetches a single row with compat-layer `params` and maps it through `f`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::query_row_with_params`, or the
    /// error returned by `f`.
    pub fn query_row_map<T, F>(
        &self,
        sql: &str,
        params: &[ParamValue],
        f: F,
    ) -> Result<T, FrankenError>
    where
        F: FnOnce(&Row) -> Result<T, FrankenError>,
    {
        let values = Self::owned_param_values(params);
        let row = self.conn.query_row_with_params(sql, &values)?;
        f(&row)
    }

    /// Maps every row produced by `sql` through `f` and collects the results.
    ///
    /// Rows are handed to `f` through `Connection::query_with_params_for_each`
    /// instead of first materializing a `Vec<Row>`.
    ///
    /// # Errors
    ///
    /// An error from `f` is returned from the per-row callback; presumably this
    /// aborts the scan (confirm against `query_with_params_for_each`). Any error
    /// that call reports is propagated.
    pub fn query_map_collect<T, F>(
        &self,
        sql: &str,
        params: &[ParamValue],
        mut f: F,
    ) -> Result<Vec<T>, FrankenError>
    where
        F: FnMut(&Row) -> Result<T, FrankenError>,
    {
        let values = Self::owned_param_values(params);
        let mut mapped = Vec::new();
        self.conn.query_with_params_for_each(sql, &values, |row| {
            mapped.push(f(row)?);
            Ok(())
        })?;
        Ok(mapped)
    }

    /// Executes `sql` as a batch on the owning connection.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Connection::execute_batch`.
    pub fn execute_batch(&self, sql: &str) -> Result<(), FrankenError> {
        // NOTE(review): fully-qualified call kept as-is; presumably it avoids
        // ambiguity with a same-named trait method — confirm before simplifying.
        Connection::execute_batch(self.conn, sql)
    }

    /// Returns the owning connection's last inserted rowid.
    ///
    /// This always returns `Ok`. The `Result` wrapper presumably keeps the
    /// signature uniform with the other methods.
    pub fn last_insert_rowid(&self) -> Result<i64, FrankenError> {
        Ok(self.conn.last_insert_rowid())
    }
}
/// Rolls back a transaction whose `commit`/`rollback` never succeeded.
///
/// `drop` has no way to report failure, so a rollback error is discarded.
impl Drop for Transaction<'_> {
    fn drop(&mut self) {
        if self.finalized {
            return;
        }
        // Best effort: the outcome is intentionally ignored.
        let _rollback_outcome = self.conn.rollback_transaction();
    }
}
/// Extension trait that lets a connection type start a [`Transaction`].
pub trait TransactionExt {
    /// Begins a transaction and returns a guard that borrows `self`.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the transaction cannot be begun.
    fn transaction(&self) -> Result<Transaction<'_>, FrankenError>;
}
impl TransactionExt for Connection {
    /// Begins a transaction on this connection via `Transaction::new`,
    /// propagating any error from starting it.
    fn transaction(&self) -> Result<Transaction<'_>, FrankenError> {
        let guard = Transaction::new(self)?;
        Ok(guard)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compat::RowExt;

    /// Creates the `t(id, val)` table every test writes into.
    fn create_table(conn: &Connection) {
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)")
            .unwrap();
    }

    /// Reads back the `val` column of every row in `t`.
    fn select_vals(conn: &Connection) -> Vec<Row> {
        conn.query("SELECT val FROM t").unwrap()
    }

    #[test]
    fn transaction_commit() {
        let conn = Connection::open(":memory:").unwrap();
        create_table(&conn);

        let mut tx = conn.transaction().unwrap();
        tx.execute("INSERT INTO t (val) VALUES ('committed')")
            .unwrap();
        tx.commit().unwrap();

        let rows = select_vals(&conn);
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get_typed::<String>(0).unwrap(), "committed");
    }

    #[test]
    fn transaction_rollback_on_drop() {
        let conn = Connection::open(":memory:").unwrap();
        create_table(&conn);

        let tx = conn.transaction().unwrap();
        tx.execute("INSERT INTO t (val) VALUES ('rolled_back')")
            .unwrap();
        // Dropping without commit must trigger the implicit rollback.
        drop(tx);

        assert!(select_vals(&conn).is_empty());
    }

    #[test]
    fn transaction_explicit_rollback() {
        let conn = Connection::open(":memory:").unwrap();
        create_table(&conn);

        let mut tx = conn.transaction().unwrap();
        tx.execute("INSERT INTO t (val) VALUES ('rolled_back')")
            .unwrap();
        tx.rollback().unwrap();

        assert!(select_vals(&conn).is_empty());
    }
}