use std::collections::{HashMap, VecDeque};
use std::path::Path;
use std::sync::{Arc, Mutex, MutexGuard};
use crate::sql::dialect::SqlriteDialect;
use sqlparser::ast::Statement as AstStatement;
use sqlparser::parser::Parser;
use crate::error::{Result, SQLRiteError};
use crate::mvcc::{
ConcurrentTx, JournalMode, MvccCommitBatch, MvccLogRecord, RowID, RowVersion, VersionPayload,
};
use crate::sql::db::database::{Database, TxnSnapshot};
use crate::sql::db::table::{Table, Value};
use crate::sql::executor::execute_select_rows;
use crate::sql::pager::{self, AccessMode, open_database_with_mode, save_database};
use crate::sql::params::{rewrite_placeholders, substitute_params};
use crate::sql::parser::select::SelectQuery;
use crate::sql::process_ast_with_render;
/// Default maximum number of compiled plans a `Connection` keeps in its
/// `prepare_cached` cache (adjustable via `set_prepared_cache_capacity`).
const DEFAULT_PREP_CACHE_CAP: usize = 16;
/// A handle onto a SQLRite database.
///
/// Handles created with [`Connection::connect`] share one `Database` behind `inner`;
/// each handle has its own prepared-plan cache and its own (optional)
/// `BEGIN CONCURRENT` transaction.
pub struct Connection {
    /// Shared, mutex-protected database state.
    inner: Arc<Mutex<Database>>,
    /// Cached compiled plans keyed by SQL text; the front entry is the least recently
    /// used and is evicted first.
    prep_cache: VecDeque<(String, Arc<CachedPlan>)>,
    /// Maximum number of entries retained in `prep_cache`.
    prep_cache_cap: usize,
    /// The `BEGIN CONCURRENT` transaction currently open on this handle, if any.
    concurrent_tx: Mutex<Option<ConcurrentTx>>,
}
impl Connection {
    /// Opens the database file at `path` for reading and writing, creating it if needed.
    ///
    /// When `path` does not exist, a new empty database named after the file stem is
    /// created, given `path` as its source path, and saved immediately.
    ///
    /// # Errors
    /// Any error from `open_database_with_mode` or from the initial `save_database`.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        let db_name = Self::db_name_for(path);
        let db = if path.exists() {
            open_database_with_mode(path, db_name, AccessMode::ReadWrite)?
        } else {
            let mut fresh = Database::new(db_name);
            fresh.source_path = Some(path.to_path_buf());
            save_database(&mut fresh, path)?;
            fresh
        };
        Ok(Self::wrap(db))
    }

    /// Opens the existing database file at `path` in read-only mode.
    ///
    /// # Errors
    /// Any error from `open_database_with_mode`.
    pub fn open_read_only<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        let db = open_database_with_mode(path, Self::db_name_for(path), AccessMode::ReadOnly)?;
        Ok(Self::wrap(db))
    }

    /// Creates a new database named `"memdb"` without assigning it a source path.
    pub fn open_in_memory() -> Result<Self> {
        Ok(Self::wrap(Database::new("memdb".to_string())))
    }

    /// Name for a file-backed database: the file stem of `path` when it is valid UTF-8,
    /// otherwise `"db"`. Shared by `open` and `open_read_only` so both agree.
    fn db_name_for(path: &Path) -> String {
        path.file_stem()
            .and_then(|stem| stem.to_str())
            .unwrap_or("db")
            .to_string()
    }

    /// Wraps `db` in a fresh handle: empty plan cache at the default capacity and no
    /// concurrent transaction.
    fn wrap(db: Database) -> Self {
        Self {
            inner: Arc::new(Mutex::new(db)),
            prep_cache: VecDeque::new(),
            prep_cache_cap: DEFAULT_PREP_CACHE_CAP,
            concurrent_tx: Mutex::new(None),
        }
    }

    /// Returns another handle onto the same underlying database.
    ///
    /// The new handle inherits this handle's cache capacity but starts with an empty
    /// plan cache and no concurrent transaction.
    pub fn connect(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            prep_cache: VecDeque::new(),
            prep_cache_cap: self.prep_cache_cap,
            concurrent_tx: Mutex::new(None),
        }
    }

    /// Number of strong references to the shared database, i.e. the live handles
    /// (this one included) created from the same `open*` call.
    pub fn handle_count(&self) -> usize {
        Arc::strong_count(&self.inner)
    }

    /// Locks the shared database.
    ///
    /// # Panics
    /// If the database mutex is poisoned.
    fn lock(&self) -> MutexGuard<'_, Database> {
        self.inner
            .lock()
            .unwrap_or_else(|e| panic!("sqlrite: database mutex poisoned: {e}"))
    }

    /// Executes one SQL command and returns its status message.
    ///
    /// `BEGIN CONCURRENT` always goes to `begin_concurrent`. `COMMIT`/`END` and
    /// `ROLLBACK` finish the concurrent transaction when one is open. Everything else
    /// (including `COMMIT`/`ROLLBACK` with no concurrent transaction) is dispatched
    /// normally, inside the concurrent transaction if one is open.
    pub fn execute(&mut self, sql: &str) -> Result<String> {
        let intent = concurrent_tx_intent(sql);
        let has_tx = self.concurrent_tx_is_open();
        match intent {
            ConcurrentTxIntent::Begin => self.begin_concurrent(),
            ConcurrentTxIntent::Commit if has_tx => self.commit_concurrent(),
            ConcurrentTxIntent::Rollback if has_tx => self.rollback_concurrent(),
            ConcurrentTxIntent::None
            | ConcurrentTxIntent::Commit
            | ConcurrentTxIntent::Rollback => self.execute_dispatch(sql),
        }
    }

    /// Same routing as [`Connection::execute`], but returns a `CommandOutput`.
    ///
    /// Concurrent-transaction control statements produce their status with
    /// `rendered: None`.
    pub fn execute_with_render(&mut self, sql: &str) -> Result<crate::sql::CommandOutput> {
        let intent = concurrent_tx_intent(sql);
        let has_tx = self.concurrent_tx_is_open();
        let status = match intent {
            ConcurrentTxIntent::Begin => self.begin_concurrent()?,
            ConcurrentTxIntent::Commit if has_tx => self.commit_concurrent()?,
            ConcurrentTxIntent::Rollback if has_tx => self.rollback_concurrent()?,
            ConcurrentTxIntent::None
            | ConcurrentTxIntent::Commit
            | ConcurrentTxIntent::Rollback => return self.execute_dispatch_with_render(sql),
        };
        Ok(crate::sql::CommandOutput {
            status,
            rendered: None,
        })
    }

    /// Whether this handle currently has a `BEGIN CONCURRENT` transaction open.
    pub fn concurrent_tx_is_open(&self) -> bool {
        self.lock_concurrent_tx().is_some()
    }

    /// Locks this handle's concurrent-transaction slot.
    ///
    /// # Panics
    /// If the slot's mutex is poisoned.
    fn lock_concurrent_tx(&self) -> MutexGuard<'_, Option<ConcurrentTx>> {
        self.concurrent_tx.lock().unwrap_or_else(|e| {
            panic!("sqlrite: concurrent_tx mutex poisoned: {e}");
        })
    }

    /// Runs `f` against the database as this handle currently sees it.
    ///
    /// With no concurrent transaction, `f` sees the live tables. Inside
    /// `BEGIN CONCURRENT`, the transaction's private tables are swapped into the
    /// database for the duration of `f`, with an empty `TxnSnapshot` installed as
    /// `db.txn`, and everything is swapped back afterwards. A drop guard performs the
    /// swap-back if `f` unwinds.
    ///
    /// Lock order: the `concurrent_tx` slot first, then the database.
    pub(crate) fn with_snapshot_read<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&Database) -> R,
    {
        let mut tx_slot = self.lock_concurrent_tx();
        let mut db = self.lock();
        match tx_slot.as_mut() {
            None => f(&db),
            Some(tx) => {
                std::mem::swap(&mut db.tables, &mut tx.tables);
                let prior_txn = db.txn.take();
                db.txn = Some(TxnSnapshot {
                    tables: HashMap::new(),
                });
                // Restores `txn` and swaps the tables back if `f` panics.
                struct UnswapGuard<'a> {
                    db: &'a mut Database,
                    tx_tables: &'a mut HashMap<String, Table>,
                    prior_txn: Option<TxnSnapshot>,
                    armed: bool,
                }
                impl Drop for UnswapGuard<'_> {
                    fn drop(&mut self) {
                        if self.armed {
                            self.db.txn = self.prior_txn.take();
                            std::mem::swap(&mut self.db.tables, self.tx_tables);
                        }
                    }
                }
                let mut guard = UnswapGuard {
                    db: &mut db,
                    tx_tables: &mut tx.tables,
                    prior_txn,
                    armed: true,
                };
                let result = f(guard.db);
                // Normal path: disarm and restore explicitly.
                guard.armed = false;
                guard.db.txn = guard.prior_txn.take();
                std::mem::swap(&mut guard.db.tables, guard.tx_tables);
                result
            }
        }
    }

    /// Runs `sql` inside the open concurrent transaction, or directly on the database.
    fn execute_dispatch(&mut self, sql: &str) -> Result<String> {
        if self.concurrent_tx_is_open() {
            self.execute_in_concurrent_tx(sql)
        } else {
            let mut db = self.lock();
            crate::sql::process_command(sql, &mut db)
        }
    }

    /// Rendering counterpart of `execute_dispatch`.
    fn execute_dispatch_with_render(&mut self, sql: &str) -> Result<crate::sql::CommandOutput> {
        if self.concurrent_tx_is_open() {
            self.execute_in_concurrent_tx_with_render(sql)
        } else {
            let mut db = self.lock();
            crate::sql::process_command_with_render(sql, &mut db)
        }
    }

    /// Opens a `BEGIN CONCURRENT` transaction on this handle.
    ///
    /// # Errors
    /// Fails if a concurrent transaction is already open on this handle, the journal
    /// mode is not MVCC, a regular transaction is open, or the database is read-only.
    fn begin_concurrent(&mut self) -> Result<String> {
        let mut tx_slot = self.lock_concurrent_tx();
        if tx_slot.is_some() {
            return Err(SQLRiteError::General(
                "cannot BEGIN CONCURRENT: a concurrent transaction is already open".to_string(),
            ));
        }
        let db = self.lock();
        if db.journal_mode() != JournalMode::Mvcc {
            return Err(SQLRiteError::General(
                "BEGIN CONCURRENT requires `PRAGMA journal_mode = mvcc;` first".to_string(),
            ));
        }
        if db.in_transaction() {
            return Err(SQLRiteError::General(
                "cannot BEGIN CONCURRENT: a non-concurrent transaction is already open"
                    .to_string(),
            ));
        }
        if db.is_read_only() {
            return Err(SQLRiteError::General(
                "cannot BEGIN CONCURRENT: database is opened read-only".to_string(),
            ));
        }
        let tx = ConcurrentTx::begin(db.mvcc_clock(), db.mv_store().active_registry(), &db.tables);
        drop(db);
        *tx_slot = Some(tx);
        Ok("BEGIN".to_string())
    }

    /// Commits the open concurrent transaction.
    ///
    /// The steps are:
    /// 1. Verify the schema is unchanged.
    /// 2. Diff the transaction's tables against their begin-time copy.
    /// 3. Reject write-write conflicts, where a written row was committed by someone
    ///    else after our `begin_ts`.
    /// 4. Publish the new versions to the MvStore and apply them to the live tables.
    /// 5. Append an MVCC log batch when a pager is present, and save when a source
    ///    path is set.
    /// 6. Garbage-collect the touched version chains.
    ///
    /// The transaction is removed from the slot up front, so every error path also
    /// ends it.
    ///
    /// # Errors
    /// `SQLRiteError::Busy` on a schema change or write-write conflict. Otherwise,
    /// any diff/apply/log/save error.
    fn commit_concurrent(&mut self) -> Result<String> {
        let mut tx_slot = self.lock_concurrent_tx();
        let tx = tx_slot
            .take()
            .expect("commit_concurrent called without active tx (caller should check)");
        drop(tx_slot);
        let mut db = self.lock();
        if !tx.schema_unchanged(&db.tables) {
            return Err(SQLRiteError::Busy(
                "schema changed under BEGIN CONCURRENT (a CREATE/DROP/ALTER ran on \
                 another connection); transaction rolled back"
                    .to_string(),
            ));
        }
        let writes = diff_tables_for_writes(&tx.tables_at_begin, &tx.tables)?;
        let mv = db.mv_store().clone();
        let begin_ts = tx.begin_ts();
        // First committer wins: abort if any row we wrote has a newer committed version.
        for (row_id, _payload) in &writes {
            if let Some(latest_begin) = mv.latest_committed_begin(row_id) {
                if latest_begin > begin_ts {
                    return Err(SQLRiteError::Busy(format!(
                        "write-write conflict on {}/{}: another transaction committed \
                         this row at ts={latest_begin} (after our begin_ts={begin_ts}); \
                         transaction rolled back, retry with a fresh BEGIN CONCURRENT",
                        row_id.table, row_id.rowid,
                    )));
                }
            }
        }
        let commit_ts = db.mvcc_clock().tick();
        for (row_id, payload) in &writes {
            let version = RowVersion::committed(commit_ts, payload.clone());
            mv.push_committed(row_id.clone(), version)
                .map_err(|e| SQLRiteError::General(format!("MvStore push failed: {e}")))?;
        }
        // NOTE(review): from here on the commit is already visible in memory (MvStore
        // and live tables). If the log append or save below fails, an error is
        // returned but that in-memory state is not undone.
        apply_writes_to_live(&mut db, &tx.tables, &writes)?;
        if let Some(pager) = db.pager.as_mut() {
            let records = writes
                .iter()
                .map(|(row, payload)| MvccLogRecord {
                    row: row.clone(),
                    payload: payload.clone(),
                })
                .collect();
            let batch = MvccCommitBatch { commit_ts, records };
            if let Err(append_err) = pager.append_mvcc_batch(&batch) {
                return Err(SQLRiteError::General(format!(
                    "COMMIT failed appending MVCC log record: {append_err}"
                )));
            }
            if let Err(set_err) = pager.observe_clock_high_water(commit_ts) {
                return Err(SQLRiteError::General(format!(
                    "COMMIT failed updating WAL clock high-water: {set_err}"
                )));
            }
        }
        if let Some(path) = db.source_path.clone() {
            if let Err(save_err) = pager::save_database(&mut db, &path) {
                return Err(SQLRiteError::General(format!(
                    "COMMIT failed during save_database: {save_err}"
                )));
            }
        }
        // Drop the transaction before reading the watermark; presumably this removes
        // it from the active registry so it no longer pins old versions.
        drop(tx);
        let watermark = mv.active_watermark();
        for (row_id, _) in &writes {
            mv.gc_chain(row_id, watermark);
        }
        Ok("COMMIT".to_string())
    }

    /// Discards the open concurrent transaction; its private tables are simply dropped.
    fn rollback_concurrent(&mut self) -> Result<String> {
        let _ = self
            .lock_concurrent_tx()
            .take()
            .expect("rollback_concurrent called without active tx (caller should check)");
        Ok("ROLLBACK".to_string())
    }

    /// Status-only wrapper around `execute_in_concurrent_tx_with_render`.
    fn execute_in_concurrent_tx(&mut self, sql: &str) -> Result<String> {
        self.execute_in_concurrent_tx_with_render(sql)
            .map(|o| o.status)
    }

    /// Runs `sql` against the open concurrent transaction's private tables.
    ///
    /// A regular `BEGIN`/`START` and DDL/VACUUM are rejected up front; the transaction
    /// stays open in both cases. Otherwise the transaction's tables are swapped into the
    /// database, with an empty `TxnSnapshot` as `db.txn`. The command runs, and
    /// everything is swapped back.
    fn execute_in_concurrent_tx_with_render(
        &mut self,
        sql: &str,
    ) -> Result<crate::sql::CommandOutput> {
        let intent = legacy_tx_intent(sql);
        if matches!(intent, LegacyTxIntent::Begin) {
            return Err(SQLRiteError::General(
                "cannot BEGIN: a concurrent transaction is already open".to_string(),
            ));
        }
        if rejects_in_concurrent_tx(sql) {
            return Err(SQLRiteError::General(
                "DDL is not supported inside BEGIN CONCURRENT (v0 limitation; the \
                 transaction stays open, the live schema is unchanged)"
                    .to_string(),
            ));
        }
        // Lock order: concurrent_tx slot, then database (same as everywhere else).
        let mut tx_slot = self.lock_concurrent_tx();
        let tx = tx_slot
            .as_mut()
            .expect("execute_in_concurrent_tx called without active tx");
        let mut db = self.lock();
        std::mem::swap(&mut db.tables, &mut tx.tables);
        let prior_txn = db.txn.take();
        db.txn = Some(TxnSnapshot {
            tables: HashMap::new(),
        });
        let result = crate::sql::process_command_with_render(sql, &mut db);
        db.txn = prior_txn;
        std::mem::swap(&mut db.tables, &mut tx.tables);
        result
    }

    /// Compiles `sql` (exactly one statement) into a fresh, uncached [`Statement`].
    ///
    /// # Errors
    /// Parse errors, or more than one / zero statements.
    pub fn prepare<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
        let plan = Arc::new(CachedPlan::compile(sql)?);
        Ok(Statement { conn: self, plan })
    }

    /// Like [`Connection::prepare`], but reuses a compiled plan for identical SQL text.
    ///
    /// Hits move to the back of the cache (most recently used). Misses are compiled,
    /// appended, and then the least recently used entries are evicted down to the
    /// capacity.
    pub fn prepare_cached<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
        let plan = if let Some(pos) = self.prep_cache.iter().position(|(k, _)| k == sql) {
            let (k, v) = self
                .prep_cache
                .remove(pos)
                .expect("position() returned an in-bounds index");
            self.prep_cache.push_back((k, Arc::clone(&v)));
            v
        } else {
            let plan = Arc::new(CachedPlan::compile(sql)?);
            self.prep_cache
                .push_back((sql.to_string(), Arc::clone(&plan)));
            self.evict_over_capacity();
            plan
        };
        Ok(Statement { conn: self, plan })
    }

    /// Sets the plan-cache capacity and evicts the least recently used entries beyond
    /// it. A capacity of 0 effectively disables caching.
    pub fn set_prepared_cache_capacity(&mut self, cap: usize) {
        self.prep_cache_cap = cap;
        self.evict_over_capacity();
    }

    /// Pops least recently used plans until the cache fits `prep_cache_cap`.
    fn evict_over_capacity(&mut self) {
        while self.prep_cache.len() > self.prep_cache_cap {
            self.prep_cache.pop_front();
        }
    }

    /// Number of plans currently held in the prepared-statement cache.
    pub fn prepared_cache_len(&self) -> usize {
        self.prep_cache.len()
    }

    /// Whether a regular (non-concurrent) transaction is open on the shared database.
    pub fn in_transaction(&self) -> bool {
        self.lock().in_transaction()
    }

    /// Current auto-vacuum threshold of the shared database, if enabled.
    pub fn auto_vacuum_threshold(&self) -> Option<f32> {
        self.lock().auto_vacuum_threshold()
    }

    /// Sets (or clears with `None`) the auto-vacuum threshold.
    ///
    /// # Errors
    /// Whatever `Database::set_auto_vacuum_threshold` rejects.
    pub fn set_auto_vacuum_threshold(&mut self, threshold: Option<f32>) -> Result<()> {
        self.lock().set_auto_vacuum_threshold(threshold)
    }

    /// Whether the shared database was opened read-only.
    pub fn is_read_only(&self) -> bool {
        self.lock().is_read_only()
    }

    /// Current journal mode of the shared database.
    pub fn journal_mode(&self) -> crate::mvcc::JournalMode {
        self.lock().journal_mode()
    }

    /// Runs MVCC garbage collection up to the current active watermark.
    ///
    /// The database lock is released before collecting. Returns whatever
    /// `MvStore::gc_all` reports (presumably the number of reclaimed versions).
    pub fn vacuum_mvcc(&self) -> usize {
        let db = self.lock();
        let mv = db.mv_store().clone();
        let watermark = mv.active_watermark();
        drop(db);
        mv.gc_all(watermark)
    }

    /// Locks and returns the shared database (hidden from rustdoc).
    #[doc(hidden)]
    pub fn database(&self) -> MutexGuard<'_, Database> {
        self.lock()
    }

    /// Same as [`Connection::database`], via `&mut self` (hidden from rustdoc).
    #[doc(hidden)]
    pub fn database_mut(&mut self) -> MutexGuard<'_, Database> {
        self.lock()
    }
}
impl std::fmt::Debug for Connection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Read the concurrent-tx flag *before* locking the database. Everywhere else the
        // lock order is `concurrent_tx` -> database (e.g. `with_snapshot_read`), so
        // taking them in the opposite order here could deadlock.
        let concurrent_tx = self.concurrent_tx_is_open();
        let db = self.lock();
        f.debug_struct("Connection")
            .field("in_transaction", &db.in_transaction())
            .field("read_only", &db.is_read_only())
            .field("tables", &db.tables.len())
            .field("prep_cache_len", &self.prep_cache.len())
            .field("handles", &Arc::strong_count(&self.inner))
            .field("concurrent_tx", &concurrent_tx)
            .finish()
    }
}
/// How `Connection::execute` classifies a statement with respect to `BEGIN CONCURRENT`
/// (see `concurrent_tx_intent`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConcurrentTxIntent {
    /// `BEGIN CONCURRENT ...`
    Begin,
    /// `COMMIT ...` or `END ...`
    Commit,
    /// `ROLLBACK ...`
    Rollback,
    /// Any other statement.
    None,
}
/// Whether a statement opens a regular (non-concurrent) transaction
/// (see `legacy_tx_intent`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LegacyTxIntent {
    /// `BEGIN ...` (other than `BEGIN CONCURRENT`) or `START ...`.
    Begin,
    /// Anything else, including `BEGIN CONCURRENT`.
    None,
}
/// Classifies `sql` by its first one or two (case-insensitive) words.
///
/// The mapping is:
/// - `BEGIN CONCURRENT` → `Begin`
/// - `COMMIT …` / `END …` → `Commit`
/// - `ROLLBACK …` → `Rollback`
/// - anything else → `None`
fn concurrent_tx_intent(sql: &str) -> ConcurrentTxIntent {
    let tokens = lowercase_tokens(sql);
    let first_word = tokens.first().map(String::as_str);
    let second_word = tokens.get(1).map(String::as_str);
    match (first_word, second_word) {
        (Some("begin"), Some("concurrent")) => ConcurrentTxIntent::Begin,
        (Some("commit" | "end"), _) => ConcurrentTxIntent::Commit,
        (Some("rollback"), _) => ConcurrentTxIntent::Rollback,
        _ => ConcurrentTxIntent::None,
    }
}
/// Detects a regular transaction opener: `BEGIN …` (but not `BEGIN CONCURRENT`) or
/// `START …`, matched case-insensitively on the leading words.
fn legacy_tx_intent(sql: &str) -> LegacyTxIntent {
    let tokens = lowercase_tokens(sql);
    let mut words = tokens.iter().map(String::as_str);
    let leading = (words.next(), words.next());
    match leading {
        (Some("begin"), Some("concurrent")) => LegacyTxIntent::None,
        (Some("begin" | "start"), _) => LegacyTxIntent::Begin,
        _ => LegacyTxIntent::None,
    }
}
/// Splits `sql` into ASCII-lowercased words.
///
/// Whitespace and the characters `;`, `(`, `)` and `,` act as separators, and empty
/// pieces are discarded.
fn lowercase_tokens(sql: &str) -> Vec<String> {
    const PUNCTUATION: [char; 4] = [';', '(', ')', ','];
    let is_separator = |ch: char| ch.is_whitespace() || PUNCTUATION.contains(&ch);
    let mut words = Vec::new();
    for piece in sql.split(is_separator) {
        if piece.is_empty() {
            continue;
        }
        words.push(piece.to_ascii_lowercase());
    }
    words
}
/// Returns `true` when `sql` begins with a statement kind that `BEGIN CONCURRENT`
/// cannot run: `CREATE`, `DROP`, `ALTER` or `VACUUM`.
///
/// The leading keyword is compared case-insensitively and may be followed by any
/// whitespace (space, tab, newline) or by `;`/`(`. The previous `starts_with("create ")`
/// form let `"CREATE\nTABLE …"` slip through the guard.
///
/// NOTE(review): only the first word is inspected, so a multi-statement string such
/// as `"INSERT …; CREATE …"` is not rejected here.
fn rejects_in_concurrent_tx(sql: &str) -> bool {
    const REJECTED: [&str; 4] = ["create", "drop", "alter", "vacuum"];
    let first_word = sql
        .split(|c: char| c.is_whitespace() || c == ';' || c == '(')
        .find(|t| !t.is_empty())
        .unwrap_or("");
    REJECTED.iter().any(|kw| first_word.eq_ignore_ascii_case(kw))
}
/// Computes the row-level writes made by a `BEGIN CONCURRENT` transaction.
///
/// Every table in `snapshot` is compared with the same-named table in `live`.
/// `commit_concurrent` passes the transaction's begin-time tables as `live` and its
/// working copy as `snapshot`. The result contains:
/// - `(row, Present(..))` for each row of `snapshot` that is missing from `live` or
///   whose column values differ;
/// - `(row, Tombstone)` for each row of `live` that no longer exists in `snapshot`.
///
/// # Errors
/// `SQLRiteError::Internal` if a table in `snapshot` has no counterpart in `live`.
fn diff_tables_for_writes(
    live: &HashMap<String, Table>,
    snapshot: &HashMap<String, Table>,
) -> Result<Vec<(RowID, VersionPayload)>> {
    let mut writes: Vec<(RowID, VersionPayload)> = Vec::new();
    for (name, snap_table) in snapshot {
        let live_table = live.get(name).ok_or_else(|| {
            SQLRiteError::Internal(format!(
                "concurrent commit: table '{name}' missing from live database"
            ))
        })?;
        // Fetch each side's rowids once (previously `live_table.rowids()` ran twice).
        let live_rowids = live_table.rowids();
        let live_set: std::collections::HashSet<i64> = live_rowids.iter().copied().collect();
        let snap_rowids = snap_table.rowids();
        let snap_set: std::collections::HashSet<i64> = snap_rowids.iter().copied().collect();

        // Inserted or updated rows.
        for &rowid in &snap_rowids {
            let snap_payload = build_payload(snap_table, rowid);
            let changed =
                !live_set.contains(&rowid) || build_payload(live_table, rowid) != snap_payload;
            if changed {
                writes.push((RowID::new(name, rowid), snap_payload));
            }
        }

        // Deleted rows.
        for &rowid in &live_rowids {
            if !snap_set.contains(&rowid) {
                writes.push((RowID::new(name, rowid), VersionPayload::Tombstone));
            }
        }
    }
    Ok(writes)
}
/// Captures row `rowid` of `table` as `VersionPayload::Present`.
///
/// The payload is a list of `(column name, value)` pairs, in the order
/// `column_names()` yields them. A `None` entry from `extract_row` becomes
/// `Value::Null`.
fn build_payload(table: &Table, rowid: i64) -> VersionPayload {
    let names = table.column_names();
    let row = table.extract_row(rowid);
    let mut pairs: Vec<(String, Value)> = Vec::new();
    for (column, maybe_value) in names.into_iter().zip(row) {
        pairs.push((column, maybe_value.unwrap_or(Value::Null)));
    }
    VersionPayload::Present(pairs)
}
/// Applies committed `writes` to the live tables of `db`.
///
/// Each target row is deleted first. A `Present` payload is then re-inserted with
/// `restore_row`, mapping `Value::Null` back to `None`; any other payload (a tombstone)
/// leaves the row deleted. `_snapshot` is currently unused.
///
/// # Errors
/// `SQLRiteError::Internal` if a target table is missing or `restore_row` fails.
fn apply_writes_to_live(
    db: &mut Database,
    _snapshot: &HashMap<String, Table>,
    writes: &[(RowID, VersionPayload)],
) -> Result<()> {
    for (row_id, payload) in writes {
        let Some(table) = db.tables.get_mut(&row_id.table) else {
            return Err(SQLRiteError::Internal(format!(
                "concurrent commit: table '{}' missing from live database",
                row_id.table
            )));
        };
        table.delete_row(row_id.rowid);
        let VersionPayload::Present(cols) = payload else {
            continue;
        };
        let values: Vec<Option<Value>> = cols
            .iter()
            .map(|(_, value)| {
                if matches!(value, Value::Null) {
                    None
                } else {
                    Some(value.clone())
                }
            })
            .collect();
        if let Err(e) = table.restore_row(row_id.rowid, values) {
            return Err(SQLRiteError::Internal(format!(
                "concurrent commit: restore_row({}) on table '{}' failed: {e}",
                row_id.rowid, row_id.table,
            )));
        }
    }
    Ok(())
}
/// A parsed single statement, with placeholders rewritten, ready for repeated execution.
#[derive(Debug)]
struct CachedPlan {
    /// Original SQL text.
    // NOTE(review): `Statement`'s `Debug` impl reads this field, so the allow below is
    // likely redundant.
    #[allow(dead_code)]
    sql: String,
    /// The parsed statement after `rewrite_placeholders`.
    ast: AstStatement,
    /// Placeholder count reported by `rewrite_placeholders`.
    param_count: usize,
    /// Pre-built `SelectQuery` when `ast` is a `Query`; `None` for other statements.
    select: Option<SelectQuery>,
}
impl CachedPlan {
fn compile(sql: &str) -> Result<Self> {
let dialect = SqlriteDialect::new();
let mut ast = Parser::parse_sql(&dialect, sql).map_err(SQLRiteError::from)?;
let Some(mut stmt) = ast.pop() else {
return Err(SQLRiteError::General("no statement to prepare".to_string()));
};
if !ast.is_empty() {
return Err(SQLRiteError::General(
"prepare() accepts a single statement; found more than one".to_string(),
));
}
let param_count = rewrite_placeholders(&mut stmt);
let select = match &stmt {
AstStatement::Query(_) => Some(SelectQuery::new(&stmt)?),
_ => None,
};
Ok(Self {
sql: sql.to_string(),
ast: stmt,
param_count,
select,
})
}
}
/// A prepared statement that exclusively borrows its [`Connection`].
///
/// The compiled plan is reference-counted. It may also be held by the connection's
/// prepared-statement cache when created via `prepare_cached`.
pub struct Statement<'c> {
    conn: &'c mut Connection,
    plan: Arc<CachedPlan>,
}
impl std::fmt::Debug for Statement<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Summarize the plan as "Select"/"Other" instead of dumping the whole AST.
        let kind = if self.plan.select.is_some() {
            "Select"
        } else {
            "Other"
        };
        let mut out = f.debug_struct("Statement");
        out.field("sql", &self.plan.sql);
        out.field("param_count", &self.plan.param_count);
        out.field("kind", &kind);
        out.finish()
    }
}
impl<'c> Statement<'c> {
pub fn parameter_count(&self) -> usize {
self.plan.param_count
}
pub fn run(&mut self) -> Result<String> {
if self.plan.param_count > 0 {
return Err(SQLRiteError::General(format!(
"statement has {} `?` placeholder(s); call execute_with_params()",
self.plan.param_count
)));
}
let ast = self.plan.ast.clone();
let mut db = self.conn.lock();
process_ast_with_render(ast, &mut db).map(|o| o.status)
}
pub fn execute_with_params(&mut self, params: &[Value]) -> Result<String> {
self.check_arity(params)?;
let mut ast = self.plan.ast.clone();
if !params.is_empty() {
substitute_params(&mut ast, params)?;
}
let mut db = self.conn.lock();
process_ast_with_render(ast, &mut db).map(|o| o.status)
}
pub fn query(&self) -> Result<Rows> {
if self.plan.param_count > 0 {
return Err(SQLRiteError::General(format!(
"statement has {} `?` placeholder(s); call query_with_params()",
self.plan.param_count
)));
}
let Some(sq) = self.plan.select.as_ref() else {
return Err(SQLRiteError::General(
"query() only works on SELECT statements; use run() for DDL/DML".to_string(),
));
};
let result = self
.conn
.with_snapshot_read(|db| execute_select_rows(sq.clone(), db))?;
Ok(Rows {
columns: result.columns,
rows: result.rows.into_iter(),
})
}
pub fn query_with_params(&self, params: &[Value]) -> Result<Rows> {
self.check_arity(params)?;
if self.plan.select.is_none() {
return Err(SQLRiteError::General(
"query_with_params() only works on SELECT statements; use execute_with_params() \
for DDL/DML"
.to_string(),
));
}
let mut ast = self.plan.ast.clone();
if !params.is_empty() {
substitute_params(&mut ast, params)?;
}
let sq = SelectQuery::new(&ast)?;
let result = self
.conn
.with_snapshot_read(|db| execute_select_rows(sq, db))?;
Ok(Rows {
columns: result.columns,
rows: result.rows.into_iter(),
})
}
fn check_arity(&self, params: &[Value]) -> Result<()> {
if params.len() != self.plan.param_count {
return Err(SQLRiteError::General(format!(
"expected {} parameter{}, got {}",
self.plan.param_count,
if self.plan.param_count == 1 { "" } else { "s" },
params.len()
)));
}
Ok(())
}
pub fn column_names(&self) -> Option<Vec<String>> {
match &self.plan.select {
Some(_) => {
None
}
None => None,
}
}
}
/// A fully materialized result set from [`Statement::query`] /
/// [`Statement::query_with_params`], consumed row by row.
pub struct Rows {
    /// Output column names.
    columns: Vec<String>,
    /// Rows not yet yielded.
    rows: std::vec::IntoIter<Vec<Value>>,
}
impl std::fmt::Debug for Rows {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show how many rows are left rather than the row values themselves.
        let remaining = self.rows.len();
        let mut out = f.debug_struct("Rows");
        out.field("columns", &self.columns);
        out.field("remaining", &remaining);
        out.finish()
    }
}
impl Rows {
pub fn columns(&self) -> &[String] {
&self.columns
}
pub fn next(&mut self) -> Result<Option<Row<'_>>> {
Ok(self.rows.next().map(|values| Row {
columns: &self.columns,
values,
}))
}
pub fn collect_all(mut self) -> Result<Vec<OwnedRow>> {
let mut out = Vec::new();
while let Some(r) = self.next()? {
out.push(r.to_owned_row());
}
Ok(out)
}
}
/// One row yielded by [`Rows::next`]; borrows its column names from the `Rows`.
pub struct Row<'r> {
    /// Column names shared with the parent `Rows`.
    columns: &'r [String],
    /// This row's values, positionally aligned with `columns`.
    values: Vec<Value>,
}
impl<'r> Row<'r> {
    /// Converts the value at position `idx` into `T`.
    ///
    /// # Errors
    /// If `idx` is out of bounds or the conversion fails.
    pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
        match self.values.get(idx) {
            Some(value) => T::from_value(value),
            None => Err(SQLRiteError::General(format!(
                "column index {idx} out of bounds (row has {} columns)",
                self.values.len()
            ))),
        }
    }

    /// Converts the value of the first column named `name` into `T`.
    ///
    /// # Errors
    /// If no column has that name, or the conversion fails.
    pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
        let Some(idx) = self.columns.iter().position(|c| c == name) else {
            return Err(SQLRiteError::General(format!(
                "no column named '{name}' in row"
            )));
        };
        self.get(idx)
    }

    /// Column names for this row.
    pub fn columns(&self) -> &[String] {
        self.columns
    }

    /// Copies this row (names and values) into an [`OwnedRow`].
    pub fn to_owned_row(&self) -> OwnedRow {
        let columns = self.columns.to_vec();
        let values = self.values.clone();
        OwnedRow { columns, values }
    }
}
/// An owned row (column names plus values), e.g. as produced by [`Rows::collect_all`].
#[derive(Debug, Clone)]
pub struct OwnedRow {
    /// Column names, positionally aligned with `values`.
    pub columns: Vec<String>,
    /// The row's values.
    pub values: Vec<Value>,
}
impl OwnedRow {
    /// Converts the value at position `idx` into `T`.
    ///
    /// # Errors
    /// If `idx` is out of bounds or the conversion fails.
    pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
        self.values.get(idx).map_or_else(
            || {
                Err(SQLRiteError::General(format!(
                    "column index {idx} out of bounds (row has {} columns)",
                    self.values.len()
                )))
            },
            T::from_value,
        )
    }

    /// Converts the value of the first column named `name` into `T`.
    ///
    /// # Errors
    /// If no column has that name, or the conversion fails.
    pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
        match self.columns.iter().position(|c| c == name) {
            Some(idx) => self.get(idx),
            None => Err(SQLRiteError::General(format!(
                "no column named '{name}' in row"
            ))),
        }
    }
}
/// Conversion from a SQL [`Value`] into a Rust type, used by `Row::get`,
/// `OwnedRow::get` and their `_by_name` variants.
pub trait FromValue: Sized {
    /// Converts `v` into `Self`, returning an error when it cannot be represented.
    fn from_value(v: &Value) -> Result<Self>;
}
impl FromValue for i64 {
    /// Accepts only `Value::Integer`; `NULL` and every other variant are errors.
    fn from_value(v: &Value) -> Result<Self> {
        if let Value::Integer(n) = v {
            return Ok(*n);
        }
        let message = if matches!(v, Value::Null) {
            "expected Integer, got NULL".to_string()
        } else {
            format!("cannot convert {v:?} to i64")
        };
        Err(SQLRiteError::General(message))
    }
}
impl FromValue for f64 {
    /// Accepts `Value::Real`, and widens `Value::Integer` with an `as` cast.
    /// `NULL` and other variants are errors.
    fn from_value(v: &Value) -> Result<Self> {
        let converted = match v {
            Value::Real(f) => *f,
            Value::Integer(n) => *n as f64,
            Value::Null => {
                return Err(SQLRiteError::General("expected Real, got NULL".to_string()));
            }
            other => {
                return Err(SQLRiteError::General(format!(
                    "cannot convert {other:?} to f64"
                )));
            }
        };
        Ok(converted)
    }
}
impl FromValue for String {
    /// Accepts only `Value::Text` (cloned); `NULL` and other variants are errors.
    fn from_value(v: &Value) -> Result<Self> {
        let Value::Text(s) = v else {
            let message = match v {
                Value::Null => "expected Text, got NULL".to_string(),
                other => format!("cannot convert {other:?} to String"),
            };
            return Err(SQLRiteError::General(message));
        };
        Ok(s.clone())
    }
}
impl FromValue for bool {
    /// Accepts `Value::Bool`, and `Value::Integer` (where non-zero means `true`).
    /// `NULL` and other variants are errors.
    fn from_value(v: &Value) -> Result<Self> {
        if let Value::Bool(flag) = v {
            Ok(*flag)
        } else if let Value::Integer(n) = v {
            Ok(*n != 0)
        } else if matches!(v, Value::Null) {
            Err(SQLRiteError::General("expected Bool, got NULL".to_string()))
        } else {
            Err(SQLRiteError::General(format!("cannot convert {v:?} to bool")))
        }
    }
}
impl<T: FromValue> FromValue for Option<T> {
    /// Maps `NULL` to `None`; any other value is converted with `T::from_value` and
    /// wrapped in `Some`.
    fn from_value(v: &Value) -> Result<Self> {
        if matches!(v, Value::Null) {
            Ok(None)
        } else {
            T::from_value(v).map(Some)
        }
    }
}
impl FromValue for Value {
fn from_value(v: &Value) -> Result<Self> {
Ok(v.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
fn tmp_path(name: &str) -> std::path::PathBuf {
let mut p = std::env::temp_dir();
let pid = std::process::id();
let nanos = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_nanos())
.unwrap_or(0);
p.push(format!("sqlrite-conn-{pid}-{nanos}-{name}.sqlrite"));
p
}
fn cleanup(path: &std::path::Path) {
let _ = std::fs::remove_file(path);
let mut wal = path.as_os_str().to_owned();
wal.push("-wal");
let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
}
#[test]
fn in_memory_roundtrip() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);")
.unwrap();
conn.execute("INSERT INTO users (name, age) VALUES ('alice', 30);")
.unwrap();
conn.execute("INSERT INTO users (name, age) VALUES ('bob', 25);")
.unwrap();
let stmt = conn.prepare("SELECT id, name, age FROM users;").unwrap();
let mut rows = stmt.query().unwrap();
assert_eq!(rows.columns(), &["id", "name", "age"]);
let mut collected: Vec<(i64, String, i64)> = Vec::new();
while let Some(row) = rows.next().unwrap() {
collected.push((
row.get::<i64>(0).unwrap(),
row.get::<String>(1).unwrap(),
row.get::<i64>(2).unwrap(),
));
}
assert_eq!(collected.len(), 2);
assert!(collected.iter().any(|(_, n, a)| n == "alice" && *a == 30));
assert!(collected.iter().any(|(_, n, a)| n == "bob" && *a == 25));
}
#[test]
fn file_backed_persists_across_connections() {
let path = tmp_path("persist");
{
let mut c1 = Connection::open(&path).unwrap();
c1.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, label TEXT);")
.unwrap();
c1.execute("INSERT INTO items (label) VALUES ('one');")
.unwrap();
}
{
let mut c2 = Connection::open(&path).unwrap();
let stmt = c2.prepare("SELECT label FROM items;").unwrap();
let mut rows = stmt.query().unwrap();
let first = rows.next().unwrap().expect("one row");
assert_eq!(first.get::<String>(0).unwrap(), "one");
assert!(rows.next().unwrap().is_none());
}
cleanup(&path);
}
#[test]
fn read_only_connection_rejects_writes() {
let path = tmp_path("ro_reject");
{
let mut c = Connection::open(&path).unwrap();
c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
.unwrap();
c.execute("INSERT INTO t (id) VALUES (1);").unwrap();
}
let mut ro = Connection::open_read_only(&path).unwrap();
assert!(ro.is_read_only());
let err = ro.execute("INSERT INTO t (id) VALUES (2);").unwrap_err();
assert!(format!("{err}").contains("read-only"));
cleanup(&path);
}
#[test]
fn transactions_work_through_connection() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);")
.unwrap();
conn.execute("INSERT INTO t (x) VALUES (1);").unwrap();
conn.execute("BEGIN;").unwrap();
assert!(conn.in_transaction());
conn.execute("INSERT INTO t (x) VALUES (2);").unwrap();
conn.execute("ROLLBACK;").unwrap();
assert!(!conn.in_transaction());
let stmt = conn.prepare("SELECT x FROM t;").unwrap();
let rows = stmt.query().unwrap().collect_all().unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
}
#[test]
fn get_by_name_works() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
conn.execute("INSERT INTO t (a, b) VALUES (42, 'hello');")
.unwrap();
let stmt = conn.prepare("SELECT a, b FROM t;").unwrap();
let mut rows = stmt.query().unwrap();
let row = rows.next().unwrap().unwrap();
assert_eq!(row.get_by_name::<i64>("a").unwrap(), 42);
assert_eq!(row.get_by_name::<String>("b").unwrap(), "hello");
}
#[test]
fn null_column_maps_to_none() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, note TEXT);")
.unwrap();
conn.execute("INSERT INTO t (id) VALUES (1);").unwrap();
let stmt = conn.prepare("SELECT id, note FROM t;").unwrap();
let mut rows = stmt.query().unwrap();
let row = rows.next().unwrap().unwrap();
assert_eq!(row.get::<i64>(0).unwrap(), 1);
assert_eq!(row.get::<Option<String>>(1).unwrap(), None);
}
#[test]
fn prepare_rejects_multiple_statements() {
let mut conn = Connection::open_in_memory().unwrap();
let err = conn.prepare("SELECT 1; SELECT 2;").unwrap_err();
assert!(format!("{err}").contains("single statement"));
}
#[test]
fn query_on_non_select_errors() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
.unwrap();
let stmt = conn.prepare("INSERT INTO t VALUES (1);").unwrap();
let err = stmt.query().unwrap_err();
assert!(format!("{err}").contains("SELECT"));
}
#[test]
fn auto_vacuum_threshold_default_and_setter() {
let mut conn = Connection::open_in_memory().unwrap();
assert_eq!(
conn.auto_vacuum_threshold(),
Some(0.25),
"fresh connection should ship with the SQLite-parity default"
);
conn.set_auto_vacuum_threshold(None).unwrap();
assert_eq!(conn.auto_vacuum_threshold(), None);
conn.set_auto_vacuum_threshold(Some(0.5)).unwrap();
assert_eq!(conn.auto_vacuum_threshold(), Some(0.5));
let err = conn.set_auto_vacuum_threshold(Some(1.5)).unwrap_err();
assert!(
format!("{err}").contains("auto_vacuum_threshold"),
"expected typed range error, got: {err}"
);
assert_eq!(
conn.auto_vacuum_threshold(),
Some(0.5),
"rejected setter call must not mutate the threshold"
);
}
#[test]
fn index_out_of_bounds_errors_cleanly() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY);")
.unwrap();
conn.execute("INSERT INTO t (a) VALUES (1);").unwrap();
let stmt = conn.prepare("SELECT a FROM t;").unwrap();
let mut rows = stmt.query().unwrap();
let row = rows.next().unwrap().unwrap();
let err = row.get::<i64>(99).unwrap_err();
assert!(format!("{err}").contains("out of bounds"));
}
#[test]
fn parameter_count_reflects_question_marks() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
let stmt = conn.prepare("SELECT a, b FROM t WHERE a = ?").unwrap();
assert_eq!(stmt.parameter_count(), 1);
let stmt = conn
.prepare("SELECT a, b FROM t WHERE a = ? AND b = ?")
.unwrap();
assert_eq!(stmt.parameter_count(), 2);
let stmt = conn.prepare("SELECT a FROM t").unwrap();
assert_eq!(stmt.parameter_count(), 0);
}
#[test]
fn query_with_params_binds_scalars() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT);")
.unwrap();
conn.execute("INSERT INTO t (a, b) VALUES (1, 'alice');")
.unwrap();
conn.execute("INSERT INTO t (a, b) VALUES (2, 'bob');")
.unwrap();
conn.execute("INSERT INTO t (a, b) VALUES (3, 'carol');")
.unwrap();
let stmt = conn.prepare("SELECT b FROM t WHERE a = ?").unwrap();
let rows = stmt
.query_with_params(&[Value::Integer(2)])
.unwrap()
.collect_all()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<String>(0).unwrap(), "bob");
}
#[test]
fn execute_with_params_binds_insert_values() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
let mut stmt = conn.prepare("INSERT INTO t (a, b) VALUES (?, ?)").unwrap();
stmt.execute_with_params(&[Value::Integer(7), Value::Text("hi".into())])
.unwrap();
stmt.execute_with_params(&[Value::Integer(8), Value::Text("yo".into())])
.unwrap();
let stmt = conn.prepare("SELECT a, b FROM t").unwrap();
let rows = stmt.query().unwrap().collect_all().unwrap();
assert_eq!(rows.len(), 2);
assert!(
rows.iter()
.any(|r| r.get::<i64>(0).unwrap() == 7 && r.get::<String>(1).unwrap() == "hi")
);
assert!(
rows.iter()
.any(|r| r.get::<i64>(0).unwrap() == 8 && r.get::<String>(1).unwrap() == "yo")
);
}
#[test]
fn arity_mismatch_returns_clean_error() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
let stmt = conn
.prepare("SELECT * FROM t WHERE a = ? AND b = ?")
.unwrap();
let err = stmt.query_with_params(&[Value::Integer(1)]).unwrap_err();
assert!(format!("{err}").contains("expected 2 parameter"));
}
#[test]
fn run_and_query_reject_when_placeholders_present() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
let mut stmt_select = conn.prepare("SELECT a FROM t WHERE a = ?").unwrap();
let err = stmt_select.query().unwrap_err();
assert!(format!("{err}").contains("query_with_params"));
let err = stmt_select.run().unwrap_err();
assert!(format!("{err}").contains("execute_with_params"));
}
#[test]
fn null_param_compares_against_null() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
conn.execute("INSERT INTO t (a) VALUES (1);").unwrap();
let stmt = conn.prepare("SELECT a FROM t WHERE a = ?").unwrap();
let rows = stmt
.query_with_params(&[Value::Null])
.unwrap()
.collect_all()
.unwrap();
assert_eq!(rows.len(), 0);
}
#[test]
fn vector_param_substitutes_through_select() {
let mut conn = Connection::open_in_memory().unwrap();
conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(3));")
.unwrap();
conn.execute("INSERT INTO v (id, e) VALUES (1, [1.0, 0.0, 0.0]);")
.unwrap();
conn.execute("INSERT INTO v (id, e) VALUES (2, [0.0, 1.0, 0.0]);")
.unwrap();
conn.execute("INSERT INTO v (id, e) VALUES (3, [0.0, 0.0, 1.0]);")
.unwrap();
let stmt = conn
.prepare("SELECT id FROM v ORDER BY vec_distance_l2(e, ?) ASC LIMIT 1")
.unwrap();
let rows = stmt
.query_with_params(&[Value::Vector(vec![1.0, 0.0, 0.0])])
.unwrap()
.collect_all()
.unwrap();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
}
#[test]
fn prepare_cached_reuses_plans() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
    for value in 1..=3 {
        let insert = format!("INSERT INTO t (a) VALUES ({value});");
        conn.execute(&insert).unwrap();
    }

    // Preparing the same SQL text twice must occupy a single cache slot.
    let lookup = "SELECT a FROM t WHERE a = ?";
    for _ in 0..2 {
        let _ = conn.prepare_cached(lookup).unwrap();
    }
    assert_eq!(conn.prepared_cache_len(), 1);

    // Distinct SQL text gets its own slot.
    let _ = conn.prepare_cached("SELECT a FROM t").unwrap();
    assert_eq!(conn.prepared_cache_len(), 2);
}
#[test]
fn prepare_cached_evicts_when_over_capacity() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
    conn.set_prepared_cache_capacity(2);

    // Fill the cache to exactly its capacity.
    for sql in ["SELECT a FROM t", "SELECT a FROM t WHERE a = ?"] {
        let _ = conn.prepare_cached(sql).unwrap();
    }
    assert_eq!(conn.prepared_cache_len(), 2);

    // A third distinct statement must not grow the cache past its capacity.
    let _ = conn.prepare_cached("SELECT a FROM t WHERE a > ?").unwrap();
    assert_eq!(conn.prepared_cache_len(), 2);
}
#[test]
fn vector_bind_through_hnsw_optimizer() {
    // One prepared statement over an HNSW-indexed column, re-executed with two
    // different bound probe vectors; each must return its exact match.
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(4));")
        .unwrap();
    let corpus: [(i64, [f32; 4]); 5] = [
        (1, [1.0, 0.0, 0.0, 0.0]),
        (2, [0.0, 1.0, 0.0, 0.0]),
        (3, [0.0, 0.0, 1.0, 0.0]),
        (4, [0.0, 0.0, 0.0, 1.0]),
        (5, [0.5, 0.5, 0.5, 0.5]),
    ];
    for (id, point) in corpus {
        // Components rendered with `Display`, comma-separated, as a vector literal.
        let coords = point.map(|x| x.to_string()).join(", ");
        conn.execute(&format!("INSERT INTO v (id, e) VALUES ({id}, [{coords}]);"))
            .unwrap();
    }
    conn.execute("CREATE INDEX v_hnsw ON v USING hnsw (e);")
        .unwrap();

    let stmt = conn
        .prepare("SELECT id FROM v ORDER BY vec_distance_l2(e, ?) ASC LIMIT 1")
        .unwrap();
    // Runs the shared statement with `probe` bound and returns the single id.
    let nearest_id = |probe: Value| -> i64 {
        let rows = stmt
            .query_with_params(&[probe])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 1);
        rows[0].get::<i64>(0).unwrap()
    };
    assert_eq!(nearest_id(Value::Vector(vec![0.0, 0.0, 1.0, 0.0])), 3);
    assert_eq!(nearest_id(Value::Vector(vec![1.0, 0.0, 0.0, 0.0])), 1);
}
#[test]
fn cosine_self_query_through_hnsw_optimizer() {
    // With a cosine-metric HNSW index, querying by one of the stored vectors
    // must return that vector's own row.
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(4));")
        .unwrap();
    let corpus: [(i64, [f32; 4]); 5] = [
        (1, [1.0, 0.0, 0.0, 0.0]),
        (2, [0.0, 1.0, 0.0, 0.0]),
        (3, [0.0, 0.0, 1.0, 0.0]),
        (4, [0.0, 0.0, 0.0, 1.0]),
        (5, [0.5, 0.5, 0.5, 0.5]),
    ];
    for (id, point) in corpus {
        let coords = point.map(|x| x.to_string()).join(", ");
        conn.execute(&format!("INSERT INTO v (id, e) VALUES ({id}, [{coords}]);"))
            .unwrap();
    }
    conn.execute("CREATE INDEX v_hnsw ON v USING hnsw (e) WITH (metric = 'cosine');")
        .unwrap();

    let sql = "SELECT id FROM v ORDER BY vec_distance_cosine(e, [0.0, 1.0, 0.0, 0.0]) ASC LIMIT 1";
    let hits = conn
        .prepare(sql)
        .unwrap()
        .query_with_params(&[])
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].get::<i64>(0).unwrap(), 2);
}
#[test]
fn dot_self_query_through_hnsw_optimizer() {
    // With a dot-product HNSW index, the probe [3, 0, 0] must select id 2
    // ([2, 0, 0]) rather than the unit vector id 1.
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(3));")
        .unwrap();
    let corpus: [(i64, [f32; 3]); 4] = [
        (1, [1.0, 0.0, 0.0]),
        (2, [2.0, 0.0, 0.0]),
        (3, [0.0, 1.0, 0.0]),
        (4, [0.0, 0.0, 1.0]),
    ];
    for (id, point) in corpus {
        let coords = point.map(|x| x.to_string()).join(", ");
        conn.execute(&format!("INSERT INTO v (id, e) VALUES ({id}, [{coords}]);"))
            .unwrap();
    }
    conn.execute("CREATE INDEX v_hnsw ON v USING hnsw (e) WITH (metric = 'dot');")
        .unwrap();

    let sql = "SELECT id FROM v ORDER BY vec_distance_dot(e, [3.0, 0.0, 0.0]) ASC LIMIT 1";
    let hits = conn
        .prepare(sql)
        .unwrap()
        .query_with_params(&[])
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].get::<i64>(0).unwrap(), 2);
}
#[test]
fn metric_mismatch_falls_back_to_brute_force() {
    // The index is created without a `WITH (metric = ...)` clause while the query
    // orders by cosine distance. Per the test name the planner must not use the
    // index, and the exact match (id 1) must still come back.
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(2));")
        .unwrap();
    let inv_sqrt2 = std::f32::consts::FRAC_1_SQRT_2;
    let corpus: [(i64, [f32; 2]); 3] = [
        (1, [1.0, 0.0]),
        (2, [inv_sqrt2, inv_sqrt2]),
        (3, [0.0, 1.0]),
    ];
    for (id, point) in corpus {
        let coords = point.map(|x| x.to_string()).join(", ");
        conn.execute(&format!("INSERT INTO v (id, e) VALUES ({id}, [{coords}]);"))
            .unwrap();
    }
    conn.execute("CREATE INDEX v_hnsw_l2 ON v USING hnsw (e);")
        .unwrap();

    let sql = "SELECT id FROM v ORDER BY vec_distance_cosine(e, [1.0, 0.0]) ASC LIMIT 1";
    let hits = conn
        .prepare(sql)
        .unwrap()
        .query_with_params(&[])
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].get::<i64>(0).unwrap(), 1);
}
#[test]
fn unknown_metric_name_is_rejected() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(2));")
        .unwrap();
    // 'cosin' is a deliberate misspelling of a real metric name.
    let msg = conn
        .execute("CREATE INDEX bad ON v USING hnsw (e) WITH (metric = 'cosin');")
        .unwrap_err()
        .to_string();
    assert!(msg.contains("unknown HNSW metric"), "got: {msg}");
}
#[test]
fn with_metric_on_btree_is_rejected() {
    // A plain (non-HNSW) index accepts no WITH options at all.
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT);")
        .unwrap();
    let msg = conn
        .execute("CREATE INDEX bad ON t (b) WITH (metric = 'cosine');")
        .unwrap_err()
        .to_string();
    assert!(msg.contains("doesn't support any options"), "got: {msg}");
}
#[test]
fn connect_shares_underlying_database() {
    // Writes through either handle land in the same database.
    let mut first = Connection::open_in_memory().unwrap();
    let mut second = first.connect();
    assert_eq!(first.handle_count(), 2);

    first
        .execute("CREATE TABLE shared (id INTEGER PRIMARY KEY, label TEXT);")
        .unwrap();
    first
        .execute("INSERT INTO shared (label) VALUES ('via-a');")
        .unwrap();
    second
        .execute("INSERT INTO shared (label) VALUES ('via-b');")
        .unwrap();

    let stmt = second.prepare("SELECT label FROM shared;").unwrap();
    let mut labels = stmt
        .query()
        .unwrap()
        .collect_all()
        .unwrap()
        .into_iter()
        .map(|row| row.get::<String>(0).unwrap())
        .collect::<Vec<_>>();
    labels.sort_unstable();
    assert_eq!(labels, ["via-a", "via-b"]);
}
#[test]
fn handle_count_reflects_live_handles() {
    let primary = Connection::open_in_memory().unwrap();
    assert_eq!(primary.handle_count(), 1);

    let mut siblings = vec![primary.connect(), primary.connect()];
    assert_eq!(primary.handle_count(), 3);

    // Release siblings oldest-first; each drop must lower the count by one.
    for expected in [2, 1] {
        drop(siblings.remove(0));
        assert_eq!(primary.handle_count(), expected);
    }
}
#[test]
fn threaded_writers_serialize_cleanly() {
    use std::thread;
    const THREADS: usize = 8;
    const PER_THREAD: usize = 25;

    // Every INSERT issued from sibling handles on separate threads must commit.
    let primary = Connection::open_in_memory().unwrap();
    primary
        .connect()
        .execute("CREATE TABLE log (id INTEGER PRIMARY KEY, who TEXT, n INTEGER);")
        .unwrap();

    // Spawn every worker before joining any, so the writers overlap.
    let workers = (0..THREADS)
        .map(|tid| {
            let mut conn = primary.connect();
            thread::spawn(move || {
                for seq in 0..PER_THREAD {
                    let sql = format!("INSERT INTO log (who, n) VALUES ('t{tid}', {seq});");
                    conn.execute(&sql).expect("insert under contention");
                }
            })
        })
        .collect::<Vec<_>>();
    workers
        .into_iter()
        .for_each(|worker| worker.join().expect("worker panicked"));

    let db = primary.database();
    let table = db.get_table("log".to_string()).unwrap();
    assert_eq!(
        table.rowids().len(),
        THREADS * PER_THREAD,
        "expected every threaded INSERT to commit",
    );
}
#[test]
fn connect_shares_file_backed_database() {
    // A row inserted through a sibling handle is visible to the primary handle
    // on a file-backed database as well.
    let path = tmp_path("connect_file");
    let mut primary = Connection::open(&path).unwrap();
    primary
        .execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v TEXT);")
        .unwrap();
    let mut sibling = primary.connect();
    sibling.execute("INSERT INTO t (v) VALUES ('hi');").unwrap();

    let stmt = primary.prepare("SELECT v FROM t;").unwrap();
    let rows = stmt.query().unwrap().collect_all().unwrap();
    let values: Vec<String> = rows
        .iter()
        .map(|row| row.get::<String>(0).unwrap())
        .collect();
    assert_eq!(values, ["hi"]);

    // Release both handles before removing the file.
    drop(sibling);
    drop(primary);
    cleanup(&path);
}
#[test]
fn prep_cache_is_per_handle() {
    let mut primary = Connection::open_in_memory().unwrap();
    primary.execute("CREATE TABLE t (a INTEGER);").unwrap();
    let mut sibling = primary.connect();
    let sql = "SELECT a FROM t";

    // Two identical prepares on the primary fill one slot in *its* cache only.
    for _ in 0..2 {
        let _ = primary.prepare_cached(sql).unwrap();
    }
    assert_eq!(primary.prepared_cache_len(), 1);
    assert_eq!(sibling.prepared_cache_len(), 0);

    // The sibling's cache fills only when it prepares the statement itself.
    let _ = sibling.prepare_cached(sql).unwrap();
    assert_eq!(sibling.prepared_cache_len(), 1);
}
#[test]
fn connection_is_send_and_sync() {
    // Compile-time check: the call only type-checks if Connection is both
    // Send and Sync.
    fn require_thread_safe<T: Send + Sync>() {}
    require_thread_safe::<Connection>();
}
#[test]
fn journal_mode_defaults_to_wal_and_renders_through_pragma() {
    use crate::mvcc::JournalMode;

    let mut conn = Connection::open_in_memory().unwrap();
    assert_eq!(conn.journal_mode(), JournalMode::Wal);

    // A bare `PRAGMA journal_mode;` reports the current mode as one row.
    let status = conn.execute("PRAGMA journal_mode;").unwrap();
    assert!(status.contains("1 row returned"), "unexpected status: {status}");
}
#[test]
fn journal_mode_set_to_mvcc_propagates_to_siblings() {
    use crate::mvcc::JournalMode;

    let mut primary = Connection::open_in_memory().unwrap();
    let sibling = primary.connect();
    assert_eq!(sibling.journal_mode(), JournalMode::Wal);

    // Each switch made through `primary` must be observed by both handles.
    let switches = [
        ("PRAGMA journal_mode = mvcc;", JournalMode::Mvcc),
        ("PRAGMA journal_mode = wal;", JournalMode::Wal),
    ];
    for (pragma, expected) in switches {
        primary.execute(pragma).unwrap();
        assert_eq!(primary.journal_mode(), expected);
        assert_eq!(sibling.journal_mode(), expected);
    }
}
#[test]
fn journal_mode_pragma_is_case_insensitive() {
    use crate::mvcc::JournalMode;

    // Upper-case keywords/values and a quoted lower-case value are both accepted.
    let mut conn = Connection::open_in_memory().unwrap();
    for (pragma, expected) in [
        ("PRAGMA JOURNAL_MODE = MVCC;", JournalMode::Mvcc),
        ("pragma journal_mode = 'wal';", JournalMode::Wal),
    ] {
        conn.execute(pragma).unwrap();
        assert_eq!(conn.journal_mode(), expected);
    }
}
#[test]
fn journal_mode_rejects_unknown_value() {
    use crate::mvcc::JournalMode;

    let mut conn = Connection::open_in_memory().unwrap();
    let msg = conn
        .execute("PRAGMA journal_mode = delete;")
        .expect_err("unknown mode must error")
        .to_string();
    assert!(
        msg.contains("unknown mode 'delete'"),
        "unexpected error: {msg}"
    );
    // A rejected pragma must leave the default mode in place.
    assert_eq!(conn.journal_mode(), JournalMode::Wal);
}
#[test]
fn journal_mode_rejects_numeric_value() {
    // A numeric journal mode is rejected with an error mentioning "numeric".
    // As with an unknown mode name, the rejection must leave the current
    // (default WAL) mode untouched rather than half-applying the pragma.
    let mut conn = Connection::open_in_memory().unwrap();
    let err = conn
        .execute("PRAGMA journal_mode = 0;")
        .expect_err("numeric mode must error");
    let msg = format!("{err}");
    assert!(msg.contains("numeric"), "unexpected error: {msg}");
    assert_eq!(conn.journal_mode(), crate::mvcc::JournalMode::Wal);
}
#[test]
fn begin_concurrent_requires_mvcc_journal_mode() {
    // In the default journal mode, BEGIN CONCURRENT must fail and tell the user
    // which pragma enables it.
    let mut conn = Connection::open_in_memory().unwrap();
    let msg = conn
        .execute("BEGIN CONCURRENT;")
        .expect_err("must require MVCC journal mode")
        .to_string();
    assert!(
        msg.contains("PRAGMA journal_mode = mvcc"),
        "unexpected error: {msg}"
    );
}
#[test]
fn begin_concurrent_then_empty_commit_round_trips() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("PRAGMA journal_mode = mvcc;").unwrap();
    conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);")
        .unwrap();

    // Each transaction-control statement reports a fixed status string.
    for (sql, expected_status) in [("BEGIN CONCURRENT;", "BEGIN"), ("COMMIT;", "COMMIT")] {
        assert_eq!(conn.execute(sql).unwrap(), expected_status);
    }
}
#[test]
fn two_concurrent_inserts_on_disjoint_rows_both_commit() {
    // Overlapping concurrent transactions that touch different rows do not
    // conflict, so both commits succeed.
    let mut first = Connection::open_in_memory().unwrap();
    first.execute("PRAGMA journal_mode = mvcc;").unwrap();
    first
        .execute("CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER);")
        .unwrap();
    let mut second = first.connect();

    first.execute("BEGIN CONCURRENT;").unwrap();
    first
        .execute("INSERT INTO accounts (id, balance) VALUES (1, 100);")
        .unwrap();
    second.execute("BEGIN CONCURRENT;").unwrap();
    second
        .execute("INSERT INTO accounts (id, balance) VALUES (2, 200);")
        .unwrap();
    first.execute("COMMIT;").unwrap();
    second.execute("COMMIT;").unwrap();

    let stmt = first.prepare("SELECT id, balance FROM accounts;").unwrap();
    let mut balances: Vec<(i64, i64)> = stmt
        .query()
        .unwrap()
        .collect_all()
        .unwrap()
        .into_iter()
        .map(|row| {
            let id = row.get::<i64>(0).unwrap();
            let balance = row.get::<i64>(1).unwrap();
            (id, balance)
        })
        .collect();
    balances.sort_unstable();
    assert_eq!(balances, vec![(1, 100), (2, 200)]);
}
#[test]
fn two_concurrent_updates_same_row_one_aborts_with_busy() {
    // Both handles update row 1 in overlapping concurrent transactions. The first
    // COMMIT succeeds; the second must fail with a retryable Busy conflict, and
    // the first writer's value must be the one that survives.
    let mut first = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE accounts (id INTEGER PRIMARY KEY, balance INTEGER);",
        "INSERT INTO accounts (id, balance) VALUES (1, 100);",
    ] {
        first.execute(sql).unwrap();
    }
    let mut second = first.connect();

    first.execute("BEGIN CONCURRENT;").unwrap();
    second.execute("BEGIN CONCURRENT;").unwrap();
    first
        .execute("UPDATE accounts SET balance = 200 WHERE id = 1;")
        .unwrap();
    second
        .execute("UPDATE accounts SET balance = 300 WHERE id = 1;")
        .unwrap();
    first.execute("COMMIT;").unwrap();

    let err = second
        .execute("COMMIT;")
        .expect_err("second commit must abort with Busy");
    assert!(matches!(err, SQLRiteError::Busy(_)));
    assert!(err.is_retryable(), "Busy must be retryable");
    let msg = err.to_string();
    assert!(
        msg.contains("write-write conflict"),
        "unexpected error: {msg}"
    );

    let stmt = first
        .prepare("SELECT balance FROM accounts WHERE id = 1;")
        .unwrap();
    let rows = stmt.query().unwrap().collect_all().unwrap();
    let balances: Vec<i64> = rows.iter().map(|row| row.get::<i64>(0).unwrap()).collect();
    assert_eq!(balances, vec![200]);
}
#[test]
fn aborted_transactions_writes_never_become_visible() {
    // Covers two ways a concurrent transaction can end without publishing its
    // writes: an explicit ROLLBACK, and a COMMIT that loses a write-write conflict.
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("PRAGMA journal_mode = mvcc;").unwrap();
    conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);")
        .unwrap();
    conn.execute("INSERT INTO t (id, v) VALUES (1, 100);")
        .unwrap();

    // 1. ROLLBACK: the pending update must be discarded.
    conn.execute("BEGIN CONCURRENT;").unwrap();
    conn.execute("UPDATE t SET v = 999 WHERE id = 1;").unwrap();
    conn.execute("ROLLBACK;").unwrap();
    let stmt = conn.prepare("SELECT v FROM t WHERE id = 1;").unwrap();
    let rows = stmt.query().unwrap().collect_all().unwrap();
    assert_eq!(rows[0].get::<i64>(0).unwrap(), 100);

    // 2. Lost conflict: `other` races `conn` on the same row and commits second.
    let mut other = conn.connect();
    conn.execute("BEGIN CONCURRENT;").unwrap();
    other.execute("BEGIN CONCURRENT;").unwrap();
    conn.execute("UPDATE t SET v = 7 WHERE id = 1;").unwrap();
    other.execute("UPDATE t SET v = 13 WHERE id = 1;").unwrap();
    conn.execute("COMMIT;").unwrap();
    let err = other.execute("COMMIT;").expect_err("must abort with Busy");
    // Pin the variant, not just "some error": an unrelated failure would
    // otherwise pass this test without exercising conflict detection.
    assert!(
        matches!(err, SQLRiteError::Busy(_)),
        "expected Busy, got: {err}"
    );

    let rows = conn
        .prepare("SELECT v FROM t WHERE id = 1;")
        .unwrap()
        .query()
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(rows[0].get::<i64>(0).unwrap(), 7);
}
#[test]
fn retry_after_busy_succeeds() {
    // After losing a conflict, the same handle can begin a fresh concurrent
    // transaction, redo its write, and commit.
    let mut a = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 1);",
    ] {
        a.execute(sql).unwrap();
    }
    let mut b = a.connect();

    a.execute("BEGIN CONCURRENT;").unwrap();
    b.execute("BEGIN CONCURRENT;").unwrap();
    a.execute("UPDATE t SET v = 100 WHERE id = 1;").unwrap();
    b.execute("UPDATE t SET v = 200 WHERE id = 1;").unwrap();
    a.execute("COMMIT;").unwrap();
    let first_attempt = b.execute("COMMIT;").expect_err("first attempt must Busy");
    assert!(first_attempt.is_retryable());

    // Retry from scratch on the losing handle.
    b.execute("BEGIN CONCURRENT;").unwrap();
    b.execute("UPDATE t SET v = 200 WHERE id = 1;").unwrap();
    b.execute("COMMIT;").expect("retry must succeed");

    let rows = a
        .prepare("SELECT v FROM t WHERE id = 1;")
        .unwrap()
        .query()
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(rows[0].get::<i64>(0).unwrap(), 200);
}
#[test]
fn nested_begin_concurrent_is_rejected() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY);",
        "BEGIN CONCURRENT;",
    ] {
        conn.execute(sql).unwrap();
    }
    // A second BEGIN CONCURRENT while one is open must be refused.
    let msg = conn
        .execute("BEGIN CONCURRENT;")
        .expect_err("nested BEGIN CONCURRENT must error")
        .to_string();
    assert!(msg.contains("already open"));
}
#[test]
fn legacy_begin_inside_concurrent_is_rejected() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY);",
        "BEGIN CONCURRENT;",
    ] {
        conn.execute(sql).unwrap();
    }
    // A plain BEGIN cannot be mixed into an open concurrent transaction.
    let msg = conn
        .execute("BEGIN;")
        .expect_err("legacy BEGIN inside concurrent tx must error")
        .to_string();
    assert!(msg.contains("concurrent transaction is already open"));
}
#[test]
fn ddl_inside_begin_concurrent_is_rejected() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in ["PRAGMA journal_mode = mvcc;", "BEGIN CONCURRENT;"] {
        conn.execute(sql).unwrap();
    }
    // Schema changes are refused inside a concurrent transaction...
    let msg = conn
        .execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
        .expect_err("DDL inside concurrent tx must error")
        .to_string();
    assert!(msg.contains("DDL is not supported"), "unexpected: {msg}");
    // ...and the transaction can still be rolled back afterwards.
    conn.execute("ROLLBACK;").unwrap();
}
#[test]
fn empty_concurrent_commit_never_busies() {
    // `idle` writes nothing inside its transaction, so a sibling committing to
    // the same row in the meantime must not make `idle`'s COMMIT fail.
    let mut idle = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 1);",
    ] {
        idle.execute(sql).unwrap();
    }
    let mut busy = idle.connect();

    idle.execute("BEGIN CONCURRENT;").unwrap();
    for sql in [
        "BEGIN CONCURRENT;",
        "UPDATE t SET v = 999 WHERE id = 1;",
        "COMMIT;",
    ] {
        busy.execute(sql).unwrap();
    }
    idle.execute("COMMIT;")
        .expect("empty commit must succeed even if siblings committed");
}
#[test]
fn query_inside_concurrent_tx_sees_begin_time_snapshot() {
    // Reads `v` for row 1 through a freshly prepared statement on `conn`.
    let read_v = |conn: &mut Connection| -> i64 {
        conn.prepare("SELECT v FROM t WHERE id = 1;")
            .unwrap()
            .query()
            .unwrap()
            .collect_all()
            .unwrap()[0]
            .get::<i64>(0)
            .unwrap()
    };

    let mut a = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 1);",
    ] {
        a.execute(sql).unwrap();
    }
    let mut b = a.connect();
    a.execute("BEGIN CONCURRENT;").unwrap();
    for sql in [
        "BEGIN CONCURRENT;",
        "UPDATE t SET v = 999 WHERE id = 1;",
        "COMMIT;",
    ] {
        b.execute(sql).unwrap();
    }

    // `a` began before `b` committed, so it must still observe the old value...
    assert_eq!(
        read_v(&mut a),
        1,
        "Statement::query inside BEGIN CONCURRENT must see the snapshot, not the live db"
    );
    a.execute("COMMIT;").unwrap();
    // ...and once its own transaction ends it sees the committed value.
    assert_eq!(read_v(&mut a), 999);
}
#[test]
fn query_inside_concurrent_tx_sees_own_writes() {
    // Reads `v` for row 1 through a freshly prepared statement on `conn`.
    let read_v = |conn: &mut Connection| -> i64 {
        conn.prepare("SELECT v FROM t WHERE id = 1;")
            .unwrap()
            .query()
            .unwrap()
            .collect_all()
            .unwrap()[0]
            .get::<i64>(0)
            .unwrap()
    };

    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 100);",
        "BEGIN CONCURRENT;",
        "UPDATE t SET v = 200 WHERE id = 1;",
    ] {
        conn.execute(sql).unwrap();
    }

    // The uncommitted update is visible to its own transaction...
    assert_eq!(read_v(&mut conn), 200);
    conn.execute("ROLLBACK;").unwrap();
    // ...and disappears once the transaction is rolled back.
    assert_eq!(read_v(&mut conn), 100);
}
#[test]
fn query_with_params_inside_concurrent_tx_sees_snapshot() {
    // The parameterized query path must honour the transaction snapshot just
    // like the parameterless one.
    let mut a = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 7);",
    ] {
        a.execute(sql).unwrap();
    }
    let mut b = a.connect();
    a.execute("BEGIN CONCURRENT;").unwrap();
    for sql in [
        "BEGIN CONCURRENT;",
        "UPDATE t SET v = 42 WHERE id = 1;",
        "COMMIT;",
    ] {
        b.execute(sql).unwrap();
    }

    let seen = a
        .prepare("SELECT v FROM t WHERE id = ?")
        .unwrap()
        .query_with_params(&[Value::Integer(1)])
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(seen[0].get::<i64>(0).unwrap(), 7);
    a.execute("COMMIT;").unwrap();
}
#[test]
fn query_outside_concurrent_tx_sees_live_database() {
    // With no transaction open, a handle reads whatever a sibling last committed.
    let mut a = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 1);",
    ] {
        a.execute(sql).unwrap();
    }
    let mut b = a.connect();
    for sql in [
        "BEGIN CONCURRENT;",
        "UPDATE t SET v = 100 WHERE id = 1;",
        "COMMIT;",
    ] {
        b.execute(sql).unwrap();
    }

    let live = a
        .prepare("SELECT v FROM t WHERE id = 1;")
        .unwrap()
        .query()
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(live[0].get::<i64>(0).unwrap(), 100);
}
#[test]
fn snapshot_stays_consistent_across_sibling_commits() {
    // Reads `v` for row 1 through a freshly prepared statement on `conn`.
    let read_v = |conn: &mut Connection| -> i64 {
        conn.prepare("SELECT v FROM t WHERE id = 1;")
            .unwrap()
            .query()
            .unwrap()
            .collect_all()
            .unwrap()[0]
            .get::<i64>(0)
            .unwrap()
    };

    let mut reader = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 1);",
    ] {
        reader.execute(sql).unwrap();
    }
    let mut writer = reader.connect();
    reader.execute("BEGIN CONCURRENT;").unwrap();
    assert_eq!(read_v(&mut reader), 1);

    // Repeated sibling commits must never leak into the reader's snapshot.
    for new_value in [10, 20, 30, 40] {
        let update = format!("UPDATE t SET v = {new_value} WHERE id = 1;");
        writer.execute("BEGIN CONCURRENT;").unwrap();
        writer.execute(&update).unwrap();
        writer.execute("COMMIT;").unwrap();
        assert_eq!(
            read_v(&mut reader),
            1,
            "snapshot regressed after writer committed v={new_value}",
        );
    }
    reader.execute("COMMIT;").unwrap();
}
#[test]
fn repeated_updates_keep_chain_bounded_when_no_readers() {
    // With no concurrent readers holding snapshots, 50 sequential updates to one
    // row must leave a single version for a single tracked row.
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE counters (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO counters (id, n) VALUES (1, 0);",
    ] {
        conn.execute(sql).unwrap();
    }
    for n in 1..=50 {
        let update = format!("UPDATE counters SET n = {n} WHERE id = 1;");
        conn.execute("BEGIN CONCURRENT;").unwrap();
        conn.execute(&update).unwrap();
        conn.execute("COMMIT;").unwrap();
    }

    // Sample the store, then release the database guard before asserting.
    let db = conn.database();
    let store = db.mv_store();
    let (store_size, tracked) = (store.total_versions(), store.tracked_rows());
    drop(db);
    assert_eq!(
        store_size, 1,
        "expected 1 version after 50 GC'd updates, got {store_size}",
    );
    assert_eq!(tracked, 1);
}
#[test]
fn gc_preserves_versions_visible_to_active_reader() {
    // While `reader` has a concurrent transaction open, five commits by `writer`
    // must not reclaim the version the reader sees. Once the reader finishes and
    // vacuum runs, at most one version may remain.
    let mut writer = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 0);",
    ] {
        writer.execute(sql).unwrap();
    }
    let mut reader = writer.connect();
    reader.execute("BEGIN CONCURRENT;").unwrap();

    for n in 1..=5 {
        let update = format!("UPDATE t SET v = {n} WHERE id = 1;");
        writer.execute("BEGIN CONCURRENT;").unwrap();
        writer.execute(&update).unwrap();
        writer.execute("COMMIT;").unwrap();
    }

    let seen = reader
        .prepare("SELECT v FROM t WHERE id = 1;")
        .unwrap()
        .query()
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(seen[0].get::<i64>(0).unwrap(), 0);
    reader.execute("COMMIT;").unwrap();

    writer.vacuum_mvcc();
    let db = writer.database();
    let store_size = db.mv_store().total_versions();
    drop(db);
    assert!(
        store_size <= 1,
        "after reader closed and vacuum ran, expected ≤1 version, got {store_size}",
    );
}
#[test]
fn vacuum_mvcc_is_a_noop_on_wal_database() {
    // A database left in the default journal mode has nothing for MVCC vacuum
    // to reclaim.
    let reclaimed = Connection::open_in_memory().unwrap().vacuum_mvcc();
    assert_eq!(reclaimed, 0);
}
#[test]
fn vacuum_mvcc_reclaims_everything_with_no_active_readers() {
    // Two committed updates with no open readers, then an explicit vacuum, must
    // leave at most one version in the store.
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "PRAGMA journal_mode = mvcc;",
        "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
        "INSERT INTO t (id, v) VALUES (1, 0);",
    ] {
        conn.execute(sql).unwrap();
    }
    for update in [
        "UPDATE t SET v = 1 WHERE id = 1;",
        "UPDATE t SET v = 2 WHERE id = 1;",
    ] {
        conn.execute("BEGIN CONCURRENT;").unwrap();
        conn.execute(update).unwrap();
        conn.execute("COMMIT;").unwrap();
    }

    let _ = conn.vacuum_mvcc();
    let db = conn.database();
    let store_size = db.mv_store().total_versions();
    drop(db);
    assert!(store_size <= 1);
}
#[test]
fn is_retryable_covers_busy_variants() {
    // Both Busy flavours are retryable; a general error is not.
    let retryable = [
        SQLRiteError::Busy("x".into()),
        SQLRiteError::BusySnapshot("x".into()),
    ];
    for err in &retryable {
        assert!(err.is_retryable());
    }
    assert!(!SQLRiteError::General("x".into()).is_retryable());
}
#[test]
fn mvcc_commit_persists_a_log_record_into_wal() {
    // One BEGIN CONCURRENT commit on a file-backed database must leave exactly one
    // MVCC commit batch, holding the inserted row, that the pager recovers after
    // the database is reopened.
    let path = tmp_path("mvcc_log_record");
    {
        let mut c = Connection::open(&path).unwrap();
        c.execute("PRAGMA journal_mode = mvcc;").unwrap();
        c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);")
            .unwrap();
        c.execute("BEGIN CONCURRENT;").unwrap();
        c.execute("INSERT INTO t (id, v) VALUES (1, 42);").unwrap();
        c.execute("COMMIT;").unwrap();
    }
    let c2 = Connection::open(&path).unwrap();
    let db = c2.database();
    let pager = db.pager.as_ref().expect("file-backed db carries a pager");
    let batches = pager.recovered_mvcc_commits();
    assert_eq!(batches.len(), 1, "one BEGIN CONCURRENT commit -> one batch");
    assert_eq!(batches[0].records.len(), 1, "one row written");
    let rec = &batches[0].records[0];
    assert_eq!(rec.row.table, "t");
    assert_eq!(rec.row.rowid, 1);
    match &rec.payload {
        VersionPayload::Present(cols) => {
            assert!(
                cols.iter()
                    .any(|(k, v)| k == "v" && matches!(v, Value::Integer(42))),
                "expected column v = 42, got: {:?}",
                rec.payload
            );
        }
        other => panic!("unexpected payload: {other:?}"),
    }
    drop(db);
    drop(c2);
    cleanup(&path);
}
#[test]
fn mvcc_reopen_restores_mv_store_and_clock() {
    // After reopening, the MvStore must know about the replayed row and the MVCC
    // clock must not lag behind the latest replayed commit timestamp.
    let path = tmp_path("mvcc_reopen");
    {
        let mut c = Connection::open(&path).unwrap();
        c.execute("PRAGMA journal_mode = mvcc;").unwrap();
        c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);")
            .unwrap();
        c.execute("BEGIN CONCURRENT;").unwrap();
        c.execute("INSERT INTO t (id, v) VALUES (1, 10);").unwrap();
        c.execute("COMMIT;").unwrap();
        c.execute("BEGIN CONCURRENT;").unwrap();
        c.execute("UPDATE t SET v = 20 WHERE id = 1;").unwrap();
        c.execute("COMMIT;").unwrap();
    }
    let c2 = Connection::open(&path).unwrap();
    let db = c2.database();
    let store = db.mv_store();
    let row = RowID::new("t", 1);
    let last_commit_ts = store
        .latest_committed_begin(&row)
        .expect("MvStore should know about row t/1 after reopen");
    // Sample the clock once so the compared value and the reported value agree.
    let now = db.mvcc_clock().now();
    assert!(
        now >= last_commit_ts,
        "clock {now} must be >= last replayed commit_ts {last_commit_ts}",
    );
    drop(db);
    drop(c2);
    cleanup(&path);
}
#[test]
fn mvcc_multi_row_batch_replays_intact() {
    // A single COMMIT that touches three rows must come back as one batch
    // containing exactly those three rows.
    let path = tmp_path("mvcc_multi_row");
    {
        let mut c = Connection::open(&path).unwrap();
        for sql in [
            "PRAGMA journal_mode = mvcc;",
            "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
            "INSERT INTO t (id, v) VALUES (1, 1);",
            "INSERT INTO t (id, v) VALUES (2, 2);",
            "INSERT INTO t (id, v) VALUES (3, 3);",
            "BEGIN CONCURRENT;",
            "UPDATE t SET v = 100 WHERE id = 1;",
            "UPDATE t SET v = 200 WHERE id = 2;",
            "UPDATE t SET v = 300 WHERE id = 3;",
            "COMMIT;",
        ] {
            c.execute(sql).unwrap();
        }
    }
    let c2 = Connection::open(&path).unwrap();
    let db = c2.database();
    let pager = db.pager.as_ref().unwrap();
    let batches = pager.recovered_mvcc_commits();
    assert_eq!(batches.len(), 1, "single COMMIT -> single batch");

    let mut rowids: Vec<i64> = batches[0].records.iter().map(|r| r.row.rowid).collect();
    rowids.sort_unstable();
    assert_eq!(rowids, vec![1, 2, 3]);

    drop(db);
    drop(c2);
    cleanup(&path);
}
#[test]
fn mvcc_rolled_back_tx_leaves_no_wal_record() {
    // A rolled-back concurrent transaction must leave no recoverable MVCC batch
    // and no versions in the store after reopening.
    let path = tmp_path("mvcc_rollback");
    {
        let mut c = Connection::open(&path).unwrap();
        for sql in [
            "PRAGMA journal_mode = mvcc;",
            "CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);",
            "BEGIN CONCURRENT;",
            "INSERT INTO t (id, v) VALUES (1, 999);",
            "ROLLBACK;",
        ] {
            c.execute(sql).unwrap();
        }
    }
    let c2 = Connection::open(&path).unwrap();
    let db = c2.database();
    let pager = db.pager.as_ref().unwrap();
    assert!(
        pager.recovered_mvcc_commits().is_empty(),
        "ROLLBACK must not append MVCC frames"
    );
    assert_eq!(db.mv_store().total_versions(), 0);
    drop(db);
    drop(c2);
    cleanup(&path);
}
#[test]
fn legacy_commit_does_not_emit_mvcc_frame() {
    // Even with journal_mode = mvcc, writes made outside BEGIN CONCURRENT must
    // not produce MVCC commit batches recoverable on reopen.
    let path = tmp_path("mvcc_legacy_no_frame");
    {
        let mut c = Connection::open(&path).unwrap();
        for sql in [
            "PRAGMA journal_mode = mvcc;",
            "CREATE TABLE t (id INTEGER PRIMARY KEY);",
            "INSERT INTO t (id) VALUES (1);",
        ] {
            c.execute(sql).unwrap();
        }
    }
    let c2 = Connection::open(&path).unwrap();
    let db = c2.database();
    let recovered = db.pager.as_ref().unwrap().recovered_mvcc_commits();
    assert!(
        recovered.is_empty(),
        "legacy writes never produce MVCC frames"
    );
    drop(db);
    drop(c2);
    cleanup(&path);
}
#[test]
fn mvcc_replays_multiple_commits_after_unclean_close() {
    // Five separate concurrent commits (one insert, four updates) must all be
    // recovered on reopen, in strictly increasing commit-timestamp order.
    let path = tmp_path("mvcc_unclean_close");
    {
        let mut c = Connection::open(&path).unwrap();
        c.execute("PRAGMA journal_mode = mvcc;").unwrap();
        c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER);")
            .unwrap();
        for v in 0..5 {
            let write = if v == 0 {
                "INSERT INTO t (id, v) VALUES (1, 0);".to_string()
            } else {
                format!("UPDATE t SET v = {v} WHERE id = 1;")
            };
            c.execute("BEGIN CONCURRENT;").unwrap();
            c.execute(&write).unwrap();
            c.execute("COMMIT;").unwrap();
        }
    }
    let c2 = Connection::open(&path).unwrap();
    let db = c2.database();
    let batches = db.pager.as_ref().unwrap().recovered_mvcc_commits();
    assert_eq!(batches.len(), 5, "every COMMIT must show up after reopen");
    assert!(batches
        .windows(2)
        .all(|pair| pair[0].commit_ts < pair[1].commit_ts));
    drop(db);
    drop(c2);
    cleanup(&path);
}
#[test]
fn prepare_cached_executes_the_same_as_prepare() {
    // Statements returned by the plan cache must bind and run parameters for
    // both writes and reads.
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT);")
        .unwrap();

    let mut ins = conn
        .prepare_cached("INSERT INTO t (a, b) VALUES (?, ?)")
        .unwrap();
    for (key, label) in [(1, "alpha"), (2, "beta")] {
        ins.execute_with_params(&[Value::Integer(key), Value::Text(label.into())])
            .unwrap();
    }

    let stmt = conn.prepare_cached("SELECT b FROM t WHERE a = ?").unwrap();
    let labels: Vec<String> = stmt
        .query_with_params(&[Value::Integer(2)])
        .unwrap()
        .collect_all()
        .unwrap()
        .iter()
        .map(|row| row.get::<String>(0).unwrap())
        .collect();
    assert_eq!(labels, ["beta"]);
}
}