use crate::core::{Error, Result, Row, Value};
use crate::executor::context::ExecutionContext;
use crate::executor::expression::ExpressionEval;
use crate::executor::result::ExecutorMemoryResult;
use crate::parser::ast::{Expression, Statement};
use crate::parser::Parser;
use crate::storage::expression::Expression as StorageExpression;
use crate::storage::traits::{QueryResult, Transaction as StorageTransaction};
use super::database::FromValue;
use super::params::Params;
use super::rows::Rows;
/// A user-facing handle over a single storage-level transaction.
///
/// Statements issued through `execute` / `query` run against the held storage
/// transaction. `commit` and `rollback` release it; a handle dropped while still
/// active is rolled back by its `Drop` implementation.
pub struct Transaction {
/// The storage transaction; becomes `None` once `commit` or `rollback` takes it.
tx: Option<Box<dyn StorageTransaction>>,
/// Set to `true` after the storage transaction committed successfully.
committed: bool,
/// Set to `true` after the storage transaction rolled back successfully.
rolled_back: bool,
}
impl Transaction {
pub(crate) fn new(tx: Box<dyn StorageTransaction>) -> Self {
Self {
tx: Some(tx),
committed: false,
rolled_back: false,
}
}
fn check_active(&self) -> Result<()> {
if self.committed {
return Err(Error::TransactionEnded);
}
if self.rolled_back {
return Err(Error::TransactionEnded);
}
if self.tx.is_none() {
return Err(Error::TransactionNotStarted);
}
Ok(())
}
/// Returns the storage transaction's identifier.
///
/// Returns `-1` once the storage transaction has been released by
/// `commit` or `rollback`.
pub fn id(&self) -> i64 {
    match &self.tx {
        Some(inner) => inner.id(),
        None => -1,
    }
}
pub fn execute<P: Params>(&mut self, sql: &str, params: P) -> Result<i64> {
self.check_active()?;
let param_values = params.into_params();
let result = self.execute_sql(sql, ¶m_values)?;
Ok(result.rows_affected())
}
/// Executes `sql` inside this transaction and returns the result rows.
///
/// `sql` may contain several statements; the rows of the last one are returned.
///
/// # Errors
/// Fails if the transaction has already ended, if `sql` does not parse, or if
/// any statement fails to execute.
pub fn query<P: Params>(&mut self, sql: &str, params: P) -> Result<Rows> {
    self.check_active()?;
    let param_values = params.into_params();
    let result = self.execute_sql(sql, &param_values)?;
    Ok(Rows::new(result))
}
pub fn query_one<T: FromValue, P: Params>(&mut self, sql: &str, params: P) -> Result<T> {
let row = self
.query(sql, params)?
.next()
.ok_or(Error::NoRowsReturned)??;
row.get(0)
}
/// Like `query_one`, but yields `Ok(None)` instead of an error when no rows
/// are returned.
///
/// # Errors
/// Propagates errors from execution, row iteration and value conversion.
pub fn query_opt<T: FromValue, P: Params>(
    &mut self,
    sql: &str,
    params: P,
) -> Result<Option<T>> {
    match self.query(sql, params)?.next() {
        None => Ok(None),
        Some(row) => row?.get(0).map(Some),
    }
}
fn execute_sql(&mut self, sql: &str, params: &[Value]) -> Result<Box<dyn QueryResult>> {
let mut parser = Parser::new(sql);
let program = parser
.parse_program()
.map_err(|e| Error::parse(e.to_string()))?;
let ctx = if params.is_empty() {
ExecutionContext::new()
} else {
ExecutionContext::with_params(params.to_vec())
};
let mut last_result: Option<Box<dyn QueryResult>> = None;
for statement in &program.statements {
last_result = Some(self.execute_statement(statement, &ctx)?);
}
last_result.ok_or(Error::NoStatementsToExecute)
}
fn execute_statement(
&mut self,
statement: &Statement,
ctx: &ExecutionContext,
) -> Result<Box<dyn QueryResult>> {
use crate::executor::result::ExecResult;
let tx = self.tx.as_mut().ok_or(Error::TransactionNotStarted)?;
match statement {
Statement::Insert(stmt) => {
let table_name = &stmt.table_name.value();
let mut table = tx.get_table(table_name)?;
let schema = table.schema().clone();
let mut total_inserted = 0i64;
for row_values in &stmt.values {
if row_values.len() > schema.columns.len() {
return Err(Error::InvalidArgumentMessage(format!(
"INSERT has {} columns but {} values",
schema.columns.len(),
row_values.len()
)));
}
let mut values = Vec::with_capacity(schema.columns.len());
for (i, expr) in row_values.iter().enumerate() {
let mut eval = ExpressionEval::compile(expr, &[])?.with_context(ctx);
let val = eval.eval_slice(&[])?;
let target_type = schema.columns[i].data_type;
values.push(val.coerce_to_type(target_type));
}
while values.len() < schema.columns.len() {
values.push(Value::null_unknown());
}
for fk in &schema.foreign_keys {
let fk_value = &values[fk.column_id];
if fk_value.is_null() {
continue;
}
let ref_table_name = fk.referenced_table.to_lowercase();
let ref_table = tx.get_table(&ref_table_name)?;
let ref_schema = ref_table.schema();
let mut expr = crate::storage::expression::ComparisonExpr::new(
fk.referenced_column_name.clone(),
crate::core::Operator::Eq,
fk_value.clone(),
);
StorageExpression::prepare_for_schema(&mut expr, ref_schema);
let mut scanner = ref_table.scan(&[0], Some(&expr))?;
if !scanner.next() {
return Err(Error::ReferentialIntegrityViolation {
message: format!(
"FOREIGN KEY constraint failed: value '{}' not present in {}({})",
fk_value, fk.referenced_table, fk.referenced_column_name
),
});
}
}
let row = Row::from_values(values);
let _ = table.insert(row)?;
total_inserted += 1;
}
Ok(Box::new(ExecResult::with_rows_affected(total_inserted)))
}
Statement::Update(stmt) => {
use crate::executor::expression::{
compile_expression, ExecuteContext, ExprVM, RowFilter, SharedProgram,
};
let table_name = &stmt.table_name.value();
let mut table = tx.get_table(table_name)?;
let columns: Vec<String> = table.schema().column_names_owned().to_vec();
let updates = stmt.updates.clone();
let where_filter: Option<RowFilter> = stmt
.where_clause
.as_ref()
.map(|expr| RowFilter::new(expr, &columns).map(|f| f.with_context(ctx)))
.transpose()?;
let compiled_updates: Vec<(usize, SharedProgram)> = updates
.iter()
.map(|(col_name, expr)| {
let idx = columns
.iter()
.position(|c| c.eq_ignore_ascii_case(col_name))
.ok_or_else(|| Error::ColumnNotFoundNamed(col_name.clone()))?;
let program = compile_expression(expr, &columns)?;
Ok((idx, program))
})
.collect::<Result<Vec<_>>>()?;
let mut vm = ExprVM::new();
let mut setter = |row: Row| -> Result<(Row, bool)> {
if let Some(ref filter) = where_filter {
if !filter.matches(&row) {
return Ok((row, false));
}
}
let row_data = row.as_slice();
let exec_ctx = ExecuteContext::new(row_data);
let mut updates_to_apply: Vec<(usize, Value)> =
Vec::with_capacity(compiled_updates.len());
for (idx, program) in compiled_updates.iter() {
match vm.execute(program, &exec_ctx) {
Ok(v) => updates_to_apply.push((*idx, v)),
Err(e) => {
return Err(e);
}
}
}
let mut new_values = row.into_values();
for (idx, value) in updates_to_apply {
new_values[idx] = value;
}
Ok((Row::from_values(new_values), true))
};
let updated_count = table.update(None, &mut setter)?;
Ok(Box::new(ExecResult::with_rows_affected(
updated_count as i64,
)))
}
Statement::Delete(stmt) => {
let table_name = &stmt.table_name.value();
let mut table = tx.get_table(table_name)?;
let where_expr = stmt
.where_clause
.as_ref()
.map(|expr| self.convert_to_storage_expression(expr, ctx))
.transpose()?;
let deleted_count = table.delete(where_expr.as_deref())?;
Ok(Box::new(ExecResult::with_rows_affected(
deleted_count as i64,
)))
}
Statement::Select(stmt) => {
let table_expr = match &stmt.table_expr {
Some(expr) => expr,
None => {
let mut columns = Vec::new();
let mut values = Vec::new();
for (i, col_expr) in stmt.columns.iter().enumerate() {
let col_name = match col_expr {
Expression::Aliased(a) => a.alias.value.clone(),
Expression::Identifier(id) => id.value.clone(),
_ => format!("expr{}", i + 1),
};
columns.push(col_name);
let mut eval =
ExpressionEval::compile(col_expr, &[])?.with_context(ctx);
values.push(eval.eval_slice(&[])?);
}
let rows = vec![Row::from_values(values)];
return Ok(Box::new(ExecutorMemoryResult::new(columns, rows)));
}
};
let table_name = match table_expr.as_ref() {
Expression::TableSource(ts) => &ts.name.value(),
Expression::Identifier(id) => &id.value,
_ => {
return Err(Error::NotSupportedMessage(
"Complex FROM clauses not supported in transactions".to_string(),
))
}
};
let table = tx.get_table(table_name)?;
let schema = table.schema();
let columns: Vec<String> = schema.column_names_owned().to_vec();
let column_indices: Vec<usize> = (0..columns.len()).collect();
let where_expr = stmt
.where_clause
.as_ref()
.map(|expr| self.convert_to_storage_expression(expr, ctx))
.transpose()?;
let mut scanner = table.scan(&column_indices, where_expr.as_deref())?;
let mut rows = Vec::new();
while scanner.next() {
rows.push(scanner.take_row());
}
if let Some(err) = scanner.err() {
return Err(err.clone());
}
let (result_columns, result_rows) = if stmt.columns.len() == 1 {
if let Expression::Star(_) = &stmt.columns[0] {
(columns, rows)
} else {
self.project_columns(stmt, &columns, rows, ctx)?
}
} else {
self.project_columns(stmt, &columns, rows, ctx)?
};
Ok(Box::new(ExecutorMemoryResult::new(
result_columns,
result_rows,
)))
}
_ => Err(Error::NotSupportedMessage(
"Only DML statements are supported in transactions".to_string(),
)),
}
}
/// Lowers a parser WHERE expression into a storage-level filter expression.
///
/// Supported shapes are `AND`/`OR` combinations of simple comparisons
/// (`=`, `==`, `!=`, `<>`, `<`, `<=`, `>`, `>=`) between a column reference and
/// an expression that needs no columns. That other side is evaluated once, up
/// front, with the parameters in `ctx`. The column may appear on either side:
/// `5 < id` is normalized to `id > 5`.
///
/// # Errors
/// Returns [`Error::NotSupportedMessage`] for any other operator or expression
/// shape, and propagates compile/evaluation errors from the non-column side.
// `&self` is only threaded through the recursive calls.
#[allow(clippy::only_used_in_recursion)]
fn convert_to_storage_expression(
    &self,
    expr: &Expression,
    ctx: &ExecutionContext,
) -> Result<Box<dyn crate::storage::expression::Expression>> {
    use crate::core::Operator;
    use crate::storage::expression::{AndExpr, ComparisonExpr, OrExpr};
    match expr {
        Expression::Infix(infix) => {
            let op_str = infix.operator.as_str();
            match op_str {
                "AND" => {
                    let left = self.convert_to_storage_expression(&infix.left, ctx)?;
                    let right = self.convert_to_storage_expression(&infix.right, ctx)?;
                    return Ok(Box::new(AndExpr::and(left, right)));
                }
                "OR" => {
                    let left = self.convert_to_storage_expression(&infix.left, ctx)?;
                    let right = self.convert_to_storage_expression(&infix.right, ctx)?;
                    return Ok(Box::new(OrExpr::or(left, right)));
                }
                _ => {}
            }
            let op = match op_str {
                "=" | "==" => Operator::Eq,
                "!=" | "<>" => Operator::Ne,
                "<" => Operator::Lt,
                "<=" => Operator::Lte,
                ">" => Operator::Gt,
                ">=" => Operator::Gte,
                _ => {
                    return Err(Error::NotSupportedMessage(format!(
                        "Operator {} not supported in transaction WHERE clause",
                        infix.operator
                    )));
                }
            };
            // Storage comparisons are always `column OP value`. If the column was
            // written on the right, mirror the operator so `5 < id` becomes `id > 5`.
            let (column, constant, op) = match (infix.left.as_ref(), infix.right.as_ref()) {
                (Expression::Identifier(id), constant) => (id.value.clone(), constant, op),
                (constant, Expression::Identifier(id)) => {
                    let mirrored = match op {
                        Operator::Lt => Operator::Gt,
                        Operator::Lte => Operator::Gte,
                        Operator::Gt => Operator::Lt,
                        Operator::Gte => Operator::Lte,
                        // `=` and `!=` are symmetric.
                        other => other,
                    };
                    (id.value.clone(), constant, mirrored)
                }
                _ => {
                    return Err(Error::NotSupportedMessage(
                        "Only comparisons with a column reference on one side are supported"
                            .to_string(),
                    ));
                }
            };
            let value = ExpressionEval::compile(constant, &[])?
                .with_context(ctx)
                .eval_slice(&[])?;
            Ok(Box::new(ComparisonExpr::new(column, op, value)))
        }
        _ => Err(Error::NotSupportedMessage(format!(
            "Expression type {:?} not supported in transaction WHERE clause",
            expr
        ))),
    }
}
/// Applies the SELECT list of `stmt` to `rows` scanned from a table whose
/// columns are `source_columns`.
///
/// Output column names:
/// * `*` expands to all source columns.
/// * An aliased expression uses its alias.
/// * A bare identifier uses its own name.
/// * Any other expression is named `expr<N>`, where `N` is its 1-based position.
///
/// Non-star items are compiled once and evaluated per row with the
/// positional and named parameters from `ctx`.
///
/// # Errors
/// Propagates compilation errors and the first evaluation error encountered.
fn project_columns(
    &self,
    stmt: &crate::parser::ast::SelectStatement,
    source_columns: &[String],
    rows: Vec<Row>,
    ctx: &ExecutionContext,
) -> Result<(Vec<String>, Vec<Row>)> {
    use crate::executor::expression::{compile_expression, ExecuteContext, ExprVM, SharedProgram};

    // Pass 1: resolve output names and compile every non-star item once.
    // `None` marks a `*` slot that copies the whole source row.
    let mut out_columns: Vec<String> = Vec::new();
    let mut programs: Vec<Option<SharedProgram>> = Vec::with_capacity(stmt.columns.len());
    for (position, item) in stmt.columns.iter().enumerate() {
        let label = match item {
            Expression::Star(_) => {
                out_columns.extend(source_columns.iter().cloned());
                programs.push(None);
                continue;
            }
            Expression::Aliased(a) => a.alias.value.clone(),
            Expression::Identifier(id) => id.value.clone(),
            _ => format!("expr{}", position + 1),
        };
        out_columns.push(label);
        programs.push(Some(compile_expression(item, source_columns)?));
    }

    // Pass 2: evaluate the compiled items against each row.
    let positional = ctx.params();
    let named: rustc_hash::FxHashMap<String, Value> = ctx
        .named_params()
        .iter()
        .map(|(k, v)| (k.clone(), v.clone()))
        .collect();
    let width = stmt.columns.len();
    let mut vm = ExprVM::new();
    let projected = rows
        .into_iter()
        .map(|row| -> Result<Row> {
            let cells = row.as_slice();
            let mut exec_ctx = ExecuteContext::new(cells);
            if !positional.is_empty() {
                exec_ctx = exec_ctx.with_params(positional);
            }
            if !named.is_empty() {
                exec_ctx = exec_ctx.with_named_params(&named);
            }
            let mut out = Vec::with_capacity(width.max(row.len()));
            for program in &programs {
                match program {
                    Some(p) => out.push(vm.execute(p, &exec_ctx)?),
                    None => out.extend_from_slice(cells),
                }
            }
            Ok(Row::from_values(out))
        })
        .collect::<Result<Vec<Row>>>()?;

    Ok((out_columns, projected))
}
/// Commits the underlying storage transaction and marks this handle as committed.
///
/// # Errors
/// Fails if the transaction already ended, or if the storage commit fails. The
/// storage transaction is released before its commit is attempted, so a failed
/// commit still leaves this handle without one.
pub fn commit(&mut self) -> Result<()> {
    self.check_active()?;
    match self.tx.take() {
        Some(mut inner) => {
            inner.commit()?;
            self.committed = true;
            Ok(())
        }
        None => Ok(()),
    }
}
/// Rolls back the underlying storage transaction.
///
/// Calling this again after a successful rollback is a no-op that returns `Ok(())`.
///
/// # Errors
/// Returns [`Error::TransactionCommitted`] if the transaction was already
/// committed, and propagates any failure from the storage rollback.
pub fn rollback(&mut self) -> Result<()> {
    if self.committed {
        return Err(Error::TransactionCommitted);
    }
    if !self.rolled_back {
        if let Some(mut inner) = self.tx.take() {
            inner.rollback()?;
            self.rolled_back = true;
        }
    }
    Ok(())
}
}
impl Drop for Transaction {
    /// Rolls back a transaction that is dropped while still active
    /// (neither committed nor rolled back).
    fn drop(&mut self) {
        let finished = self.committed || self.rolled_back;
        if !finished {
            // `drop` cannot report failures, so a rollback error is deliberately discarded.
            let _ = self.rollback();
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::api::Database;

    /// Changes made inside a transaction become visible to the database after commit.
    #[test]
    fn test_transaction_commit() {
        let db = Database::open_in_memory().unwrap();
        db.execute(
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)",
            (),
        )
        .unwrap();
        let mut tx = db.begin().unwrap();
        let inserted = tx
            .execute("INSERT INTO test VALUES ($1, $2)", (1, 100))
            .unwrap();
        assert_eq!(inserted, 1);
        tx.commit().unwrap();
        let value: i64 = db
            .query_one("SELECT value FROM test WHERE id = $1", (1,))
            .unwrap();
        assert_eq!(value, 100);
    }

    /// An explicit rollback discards changes made inside the transaction.
    #[test]
    fn test_transaction_rollback() {
        let db = Database::open_in_memory().unwrap();
        db.execute(
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)",
            (),
        )
        .unwrap();
        db.execute("INSERT INTO test VALUES ($1, $2)", (1, 100))
            .unwrap();
        let mut tx = db.begin().unwrap();
        tx.execute("UPDATE test SET value = $1 WHERE id = $2", (200, 1))
            .unwrap();
        tx.rollback().unwrap();
        let value: i64 = db
            .query_one("SELECT value FROM test WHERE id = $1", (1,))
            .unwrap();
        assert_eq!(value, 100);
    }

    /// Dropping an active transaction rolls it back.
    #[test]
    fn test_transaction_auto_rollback() {
        let db = Database::open_in_memory().unwrap();
        db.execute(
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)",
            (),
        )
        .unwrap();
        db.execute("INSERT INTO test VALUES ($1, $2)", (1, 100))
            .unwrap();
        {
            let mut tx = db.begin().unwrap();
            tx.execute("UPDATE test SET value = $1 WHERE id = $2", (200, 1))
                .unwrap();
        }
        let value: i64 = db
            .query_one("SELECT value FROM test WHERE id = $1", (1,))
            .unwrap();
        assert_eq!(value, 100);
    }

    #[test]
    fn test_transaction_query() {
        let db = Database::open_in_memory().unwrap();
        db.execute(
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)",
            (),
        )
        .unwrap();
        db.execute("INSERT INTO test VALUES ($1, $2)", (1, 100))
            .unwrap();
        let mut tx = db.begin().unwrap();
        for row in tx.query("SELECT * FROM test", ()).unwrap() {
            let row = row.unwrap();
            assert_eq!(row.get::<i64>(0).unwrap(), 1);
            assert_eq!(row.get::<i64>(1).unwrap(), 100);
        }
        tx.commit().unwrap();
    }

    #[test]
    fn test_transaction_query_one() {
        let db = Database::open_in_memory().unwrap();
        db.execute(
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)",
            (),
        )
        .unwrap();
        db.execute("INSERT INTO test VALUES ($1, $2)", (1, 100))
            .unwrap();
        let mut tx = db.begin().unwrap();
        let value: i64 = tx
            .query_one("SELECT value FROM test WHERE id = $1", (1,))
            .unwrap();
        assert_eq!(value, 100);
        tx.commit().unwrap();
    }

    /// `query_opt` yields `None` rather than an error when nothing matches.
    #[test]
    fn test_transaction_query_opt_none() {
        let db = Database::open_in_memory().unwrap();
        db.execute(
            "CREATE TABLE test (id INTEGER PRIMARY KEY, value INTEGER)",
            (),
        )
        .unwrap();
        let mut tx = db.begin().unwrap();
        let missing: Option<i64> = tx
            .query_opt("SELECT value FROM test WHERE id = $1", (42,))
            .unwrap();
        assert!(missing.is_none());
        tx.commit().unwrap();
    }

    #[test]
    fn test_committed_transaction_error() {
        let db = Database::open_in_memory().unwrap();
        db.execute("CREATE TABLE test (id INTEGER PRIMARY KEY)", ())
            .unwrap();
        let mut tx = db.begin().unwrap();
        tx.commit().unwrap();
        assert!(tx.execute("INSERT INTO test VALUES ($1)", (1,)).is_err());
        assert!(tx.commit().is_err());
    }

    /// A committed transaction can no longer be rolled back.
    #[test]
    fn test_rollback_after_commit_errors() {
        let db = Database::open_in_memory().unwrap();
        let mut tx = db.begin().unwrap();
        tx.commit().unwrap();
        assert!(tx.rollback().is_err());
    }

    /// Rolling back twice is harmless.
    #[test]
    fn test_rollback_is_idempotent() {
        let db = Database::open_in_memory().unwrap();
        let mut tx = db.begin().unwrap();
        tx.rollback().unwrap();
        assert!(tx.rollback().is_ok());
    }

    #[test]
    fn test_transaction_id() {
        let db = Database::open_in_memory().unwrap();
        let tx = db.begin().unwrap();
        assert!(tx.id() > 0);
    }
}