pub mod db;
pub mod executor;
pub mod hnsw;
pub mod pager;
pub mod parser;
use parser::create::CreateQuery;
use parser::insert::InsertQuery;
use parser::select::SelectQuery;
use sqlparser::ast::Statement;
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::{Parser, ParserError};
use crate::error::{Result, SQLRiteError};
use crate::sql::db::database::Database;
use crate::sql::db::table::Table;
/// Coarse classification of a raw SQL command by its leading keyword.
///
/// Every variant carries the complete, unmodified command text.
#[derive(Debug, PartialEq)]
pub enum SQLCommand {
    /// Command whose first word is `insert`.
    Insert(String),
    /// Command whose first word is `delete`.
    Delete(String),
    /// Command whose first word is `update`.
    Update(String),
    /// Command whose first word is `create`.
    CreateTable(String),
    /// Command whose first word is `select`.
    Select(String),
    /// Anything else, including empty or whitespace-only input.
    Unknown(String),
}

impl SQLCommand {
    /// Classifies `command` by its first whitespace-delimited word.
    ///
    /// The keyword comparison ignores ASCII case and any leading whitespace,
    /// and the keyword may be followed by any whitespace character (space,
    /// tab, newline). The original text is stored in the returned variant
    /// untouched.
    pub fn new(command: String) -> SQLCommand {
        // Lower-casing yields an owned keyword, so `command` stays free to be
        // moved into the variant below.
        let keyword = command
            .split_whitespace()
            .next()
            .unwrap_or("")
            .to_ascii_lowercase();
        match keyword.as_str() {
            "insert" => SQLCommand::Insert(command),
            "update" => SQLCommand::Update(command),
            "delete" => SQLCommand::Delete(command),
            "create" => SQLCommand::CreateTable(command),
            "select" => SQLCommand::Select(command),
            _ => SQLCommand::Unknown(command),
        }
    }
}
/// Parses `query` with the SQLite dialect and executes the single statement
/// it contains against `db`, returning a human-readable status message.
///
/// Behaviour by statement kind:
/// - No statement at all (empty input): returns `"No statement to execute."`.
/// - `BEGIN` / `COMMIT` / `ROLLBACK`: delegated to the transaction methods on
///   `db`. `COMMIT` first saves to `db.source_path` (when set); if that save
///   fails the transaction is rolled back and an error is returned.
/// - `CREATE TABLE`, `CREATE INDEX`, `INSERT`, `UPDATE`, `DELETE`: rejected
///   up front when `db.is_read_only()`. After a successful write outside a
///   transaction the database is saved to `db.source_path` when one is set.
/// - `SELECT`: the rendered result is printed to stdout and the row count is
///   reported in the returned message.
///
/// # Errors
///
/// Returns an error when parsing fails, when more than one statement is
/// supplied, when the statement kind is not supported, or when the
/// statement-specific execution or the subsequent save fails.
pub fn process_command(query: &str, db: &mut Database) -> Result<String> {
    let dialect = SQLiteDialect {};
    let mut ast = Parser::parse_sql(&dialect, query).map_err(SQLRiteError::from)?;

    if ast.len() > 1 {
        return Err(SQLRiteError::SqlError(ParserError::ParserError(format!(
            "Expected a single query statement, but there are {}",
            ast.len()
        ))));
    }

    let Some(query) = ast.pop() else {
        return Ok("No statement to execute.".to_string());
    };

    // Transaction control statements are handled first and return directly;
    // they never reach the autosave step at the bottom.
    match &query {
        Statement::StartTransaction { .. } => {
            db.begin_transaction()?;
            return Ok(String::from("BEGIN"));
        }
        Statement::Commit { .. } => {
            if !db.in_transaction() {
                return Err(SQLRiteError::General(
                    "cannot COMMIT: no transaction is open".to_string(),
                ));
            }
            // Persist before clearing the transaction so a failed save can
            // still be undone in memory.
            if let Some(path) = db.source_path.clone() {
                if let Err(save_err) = pager::save_database(db, &path) {
                    // Best effort: the save error is what gets reported; the
                    // rollback's own result is discarded.
                    let _ = db.rollback_transaction();
                    return Err(SQLRiteError::General(format!(
                        "COMMIT failed — transaction rolled back: {save_err}"
                    )));
                }
            }
            db.commit_transaction()?;
            return Ok(String::from("COMMIT"));
        }
        Statement::Rollback { .. } => {
            db.rollback_transaction()?;
            return Ok(String::from("ROLLBACK"));
        }
        _ => {}
    }

    let is_write_statement = matches!(
        &query,
        Statement::CreateTable(_)
            | Statement::CreateIndex(_)
            | Statement::Insert(_)
            | Statement::Update(_)
            | Statement::Delete(_)
    );
    if is_write_statement && db.is_read_only() {
        return Err(SQLRiteError::General(
            "cannot execute: database is opened read-only".to_string(),
        ));
    }

    let message = match &query {
        Statement::CreateTable(_) => {
            let payload = CreateQuery::new(&query)?;
            let table_name = payload.table_name.clone();
            if table_name == pager::MASTER_TABLE_NAME {
                return Err(SQLRiteError::General(format!(
                    "'{}' is a reserved name used by the internal schema catalog",
                    pager::MASTER_TABLE_NAME
                )));
            }
            if db.contains_table(table_name.to_string()) {
                return Err(SQLRiteError::Internal(
                    "Cannot create, table already exists.".to_string(),
                ));
            }
            let table = Table::new(payload);
            let _ = table.print_table_schema();
            db.tables.insert(table_name.to_string(), table);
            String::from("CREATE TABLE Statement executed.")
        }
        Statement::Insert(_) => {
            let payload = InsertQuery::new(&query)?;
            let table_name = payload.table_name;
            let columns = payload.columns;
            let values = payload.rows;

            if !db.contains_table(table_name.to_string()) {
                return Err(SQLRiteError::Internal("Table doesn't exist".to_string()));
            }
            let db_table = db
                .get_table_mut(table_name.to_string())
                .expect("table existence was checked just above");

            if !columns
                .iter()
                .all(|column| db_table.contains_column(column.to_string()))
            {
                return Err(SQLRiteError::Internal(
                    "Cannot insert, some of the columns do not exist".to_string(),
                ));
            }

            // Validate the arity of every row before inserting any of them,
            // so a malformed later row cannot leave earlier rows of the same
            // statement inserted in memory.
            if let Some(bad_row) = values.iter().find(|row| row.len() != columns.len()) {
                return Err(SQLRiteError::Internal(format!(
                    "{} values for {} columns",
                    bad_row.len(),
                    columns.len()
                )));
            }

            // NOTE(review): a uniqueness or insert failure on a later row
            // still leaves earlier rows of this statement in place.
            for value in &values {
                db_table
                    .validate_unique_constraint(&columns, value)
                    .map_err(|err| {
                        SQLRiteError::Internal(format!("Unique key constraint violation: {err}"))
                    })?;
                db_table.insert_row(&columns, value)?;
            }
            db_table.print_table_data();
            String::from("INSERT Statement executed.")
        }
        Statement::Query(_) => {
            let select_query = SelectQuery::new(&query)?;
            let (rendered, rows) = executor::execute_select(select_query, db)?;
            print!("{rendered}");
            format!(
                "SELECT Statement executed. {rows} row{s} returned.",
                s = if rows == 1 { "" } else { "s" }
            )
        }
        Statement::Delete(_) => {
            let rows = executor::execute_delete(&query, db)?;
            format!(
                "DELETE Statement executed. {rows} row{s} deleted.",
                s = if rows == 1 { "" } else { "s" }
            )
        }
        Statement::Update(_) => {
            let rows = executor::execute_update(&query, db)?;
            format!(
                "UPDATE Statement executed. {rows} row{s} updated.",
                s = if rows == 1 { "" } else { "s" }
            )
        }
        Statement::CreateIndex(_) => {
            let name = executor::execute_create_index(&query, db)?;
            format!("CREATE INDEX '{name}' executed.")
        }
        _ => {
            return Err(SQLRiteError::NotImplemented(
                "SQL Statement not supported yet.".to_string(),
            ));
        }
    };

    // Autosave after a successful write, unless a transaction is open (then
    // COMMIT is responsible for persisting).
    if is_write_statement && !db.in_transaction() {
        if let Some(path) = db.source_path.clone() {
            pager::save_database(db, &path)?;
        }
    }
    Ok(message)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sql::db::table::Value;
/// Builds an in-memory database named "tempdb" with a `users` table
/// (id, name, age) holding alice/30, bob/25 and carol/40.
fn seed_users_table() -> Database {
    let mut db = Database::new("tempdb".to_string());
    // (statement, message used if the statement fails)
    let setup = [
        (
            "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT NOT NULL, age INTEGER);",
            "create table",
        ),
        (
            "INSERT INTO users (name, age) VALUES ('alice', 30);",
            "insert alice",
        ),
        (
            "INSERT INTO users (name, age) VALUES ('bob', 25);",
            "insert bob",
        ),
        (
            "INSERT INTO users (name, age) VALUES ('carol', 40);",
            "insert carol",
        ),
    ];
    for (sql, failure_label) in setup {
        process_command(sql, &mut db).expect(failure_label);
    }
    db
}
/// `SELECT *` on the seeded `users` table reports three returned rows.
#[test]
fn process_command_select_all_test() {
    let mut db = seed_users_table();
    let outcome = process_command("SELECT * FROM users;", &mut db);
    let summary = outcome.expect("select");
    assert!(summary.contains("3 rows returned"));
}
/// A `WHERE age > 25` filter keeps two of the seeded rows (ages 30 and 40).
#[test]
fn process_command_select_where_test() {
    let sql = "SELECT name FROM users WHERE age > 25;";
    let mut db = seed_users_table();
    let summary = process_command(sql, &mut db).expect("select");
    assert!(summary.contains("2 rows returned"));
}
/// String equality in `WHERE` matches exactly one seeded row and the
/// message uses the singular "row".
#[test]
fn process_command_select_eq_string_test() {
    let sql = "SELECT name FROM users WHERE name = 'bob';";
    let mut db = seed_users_table();
    let summary = process_command(sql, &mut db).expect("select");
    assert!(summary.contains("1 row returned"));
}
/// `ORDER BY ... LIMIT 2` caps the reported row count at two.
#[test]
fn process_command_select_limit_test() {
    let sql = "SELECT * FROM users ORDER BY age ASC LIMIT 2;";
    let mut db = seed_users_table();
    let summary = process_command(sql, &mut db).expect("select");
    assert!(summary.contains("2 rows returned"));
}
/// Selecting from a table that was never created is an error.
#[test]
fn process_command_select_unknown_table_test() {
    let mut empty_db = Database::new("tempdb".to_string());
    let outcome = process_command("SELECT * FROM nope;", &mut empty_db);
    assert!(outcome.is_err());
}
/// Projecting a column the table does not have is an error.
#[test]
fn process_command_select_unknown_column_test() {
    let mut db = seed_users_table();
    let outcome = process_command("SELECT height FROM users;", &mut db);
    assert!(outcome.is_err());
}
/// INSERT into a table (with a primary key) that was registered directly in
/// `db.tables` reports success.
#[test]
fn process_command_insert_test() {
    let mut db = Database::new("tempdb".to_string());
    let query_statement = "CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT
);";
    let dialect = SQLiteDialect {};
    let mut ast = Parser::parse_sql(&dialect, query_statement).unwrap();
    assert_eq!(ast.len(), 1, "expected exactly one parsed statement");
    let query = ast.pop().unwrap();
    let create_query = CreateQuery::new(&query).unwrap();
    db.tables.insert(
        create_query.table_name.to_string(),
        Table::new(create_query),
    );

    // `expect` puts the underlying error into the panic message instead of
    // printing it to stderr and then failing on a constant `assert!(false)`.
    let response = process_command("INSERT INTO users (name) Values ('josh');", &mut db)
        .expect("insert into an existing table should succeed");
    assert_eq!(response, "INSERT Statement executed.");
}
/// INSERT into a table declared without a primary key, registered directly
/// in `db.tables`, reports success.
#[test]
fn process_command_insert_no_pk_test() {
    let mut db = Database::new("tempdb".to_string());
    let query_statement = "CREATE TABLE users (
name TEXT
);";
    let dialect = SQLiteDialect {};
    let mut ast = Parser::parse_sql(&dialect, query_statement).unwrap();
    assert_eq!(ast.len(), 1, "expected exactly one parsed statement");
    let query = ast.pop().unwrap();
    let create_query = CreateQuery::new(&query).unwrap();
    db.tables.insert(
        create_query.table_name.to_string(),
        Table::new(create_query),
    );

    // `expect` puts the underlying error into the panic message instead of
    // printing it to stderr and then failing on a constant `assert!(false)`.
    let response = process_command("INSERT INTO users (name) Values ('josh');", &mut db)
        .expect("insert into an existing table should succeed");
    assert_eq!(response, "INSERT Statement executed.");
}
/// A filtered DELETE removes one row and a follow-up SELECT sees the other two.
#[test]
fn process_command_delete_where_test() {
    let mut db = seed_users_table();

    let deleted = process_command("DELETE FROM users WHERE name = 'bob';", &mut db);
    assert!(deleted.expect("delete").contains("1 row deleted"));

    let survivors = process_command("SELECT * FROM users;", &mut db);
    assert!(survivors.expect("select").contains("2 rows returned"));
}
/// An unfiltered DELETE reports all three seeded rows as deleted.
#[test]
fn process_command_delete_all_test() {
    let mut db = seed_users_table();
    let outcome = process_command("DELETE FROM users;", &mut db);
    let summary = outcome.expect("delete");
    assert!(summary.contains("3 rows deleted"));
}
/// A filtered UPDATE changes exactly bob's age, verified by reading the
/// stored value back from the table.
#[test]
fn process_command_update_where_test() {
    let mut db = seed_users_table();
    let summary = process_command("UPDATE users SET age = 99 WHERE name = 'bob';", &mut db)
        .expect("update");
    assert!(summary.contains("1 row updated"));

    let table = db.get_table("users".to_string()).unwrap();
    let wanted_name = Some(Value::Text("bob".to_string()));
    let bob_rowid = table
        .rowids()
        .into_iter()
        .find(|rowid| table.get_value("name", *rowid) == wanted_name)
        .expect("bob row must exist");
    let stored_age = table.get_value("age", bob_rowid);
    assert_eq!(stored_age, Some(Value::Integer(99)));
}
/// Updating a UNIQUE column to a value another row already holds must fail.
#[test]
fn process_command_update_unique_violation_test() {
    let mut db = seed_users_table();
    for sql in [
        "CREATE TABLE tags (id INTEGER PRIMARY KEY, label TEXT UNIQUE);",
        "INSERT INTO tags (label) VALUES ('a');",
        "INSERT INTO tags (label) VALUES ('b');",
    ] {
        process_command(sql, &mut db).unwrap();
    }
    let result = process_command("UPDATE tags SET label = 'a' WHERE label = 'b';", &mut db);
    assert!(result.is_err(), "expected UNIQUE violation, got {result:?}");
}
/// Inserting text into an INTEGER column returns an error.
#[test]
fn process_command_insert_type_mismatch_returns_error_test() {
    let schema = "CREATE TABLE items (id INTEGER PRIMARY KEY, qty INTEGER);";
    let mut db = Database::new("tempdb".to_string());
    process_command(schema, &mut db).unwrap();

    let result = process_command("INSERT INTO items (qty) VALUES ('not a number');", &mut db);
    assert!(result.is_err(), "expected error, got {result:?}");
}
/// An INSERT that omits the `qty` INTEGER column returns an error.
#[test]
fn process_command_insert_missing_integer_returns_error_test() {
    let schema = "CREATE TABLE items (id INTEGER PRIMARY KEY, qty INTEGER);";
    let mut db = Database::new("tempdb".to_string());
    process_command(schema, &mut db).unwrap();

    let result = process_command("INSERT INTO items (id) VALUES (1);", &mut db);
    assert!(result.is_err(), "expected error, got {result:?}");
}
/// `SET age = age + 1` without a WHERE clause bumps every seeded age by one.
#[test]
fn process_command_update_arith_test() {
    let mut db = seed_users_table();
    process_command("UPDATE users SET age = age + 1;", &mut db).expect("update +1");

    let table = db.get_table("users".to_string()).unwrap();
    // Collect only integer ages; any other value is skipped.
    let mut ages: Vec<i64> = Vec::new();
    for rowid in table.rowids() {
        if let Some(Value::Integer(age)) = table.get_value("age", rowid) {
            ages.push(age);
        }
    }
    ages.sort();
    assert_eq!(ages, vec![26, 31, 41]);
}
/// Arithmetic inside WHERE: `age * 2 > 55` holds for ages 30 and 40.
#[test]
fn process_command_select_arithmetic_where_test() {
    let sql = "SELECT name FROM users WHERE age * 2 > 55;";
    let mut db = seed_users_table();
    let summary = process_command(sql, &mut db).expect("select");
    assert!(summary.contains("2 rows returned"));
}
/// Division by zero in a projection returns an error.
#[test]
fn process_command_divide_by_zero_test() {
    let mut db = seed_users_table();
    let outcome = process_command("SELECT age / 0 FROM users;", &mut db);
    assert!(outcome.is_err());
}
/// `DROP TABLE` against an empty database returns an error.
#[test]
fn process_command_unsupported_statement_test() {
    let mut empty_db = Database::new("tempdb".to_string());
    let outcome = process_command("DROP TABLE users;", &mut empty_db);
    assert!(outcome.is_err());
}
/// Empty, blank and comment-only inputs succeed with the
/// "No statement" message.
#[test]
fn empty_input_is_a_noop_not_a_panic() {
    let blank_inputs = ["", " ", "-- just a comment", "-- comment\n-- another"];
    let mut db = Database::new("t".to_string());
    for input in blank_inputs {
        let result = process_command(input, &mut db);
        assert!(result.is_ok(), "input {input:?} should not error");
        let msg = result.unwrap();
        assert!(msg.contains("No statement"), "got: {msg:?}");
    }
}
/// CREATE INDEX echoes the index name and registers a non-unique index on
/// the named column.
#[test]
fn create_index_adds_explicit_index() {
    let mut db = seed_users_table();
    let summary = process_command("CREATE INDEX users_age_idx ON users (age);", &mut db)
        .expect("create index");
    assert!(summary.contains("users_age_idx"));

    let table = db.get_table("users".to_string()).unwrap();
    let index = table
        .index_by_name("users_age_idx")
        .expect("index should exist after CREATE INDEX");
    assert_eq!(index.column_name, "age");
    assert!(!index.is_unique);
}
/// CREATE UNIQUE INDEX fails when the column already holds a duplicate
/// (alice and dan are both 30).
#[test]
fn create_unique_index_rejects_duplicate_existing_values() {
    let mut db = seed_users_table();
    process_command("INSERT INTO users (name, age) VALUES ('dan', 30);", &mut db).unwrap();

    let sql = "CREATE UNIQUE INDEX users_age_unique ON users (age);";
    let result = process_command(sql, &mut db);
    assert!(
        result.is_err(),
        "expected unique-index failure, got {result:?}"
    );
}
/// Equality lookup on an indexed column returns the right rows: of 100
/// inserted rows, every third is tagged 'hot', so 33 must match.
/// NOTE(review): the test checks only the row count, not that an index
/// probe was actually used.
#[test]
fn where_eq_on_indexed_column_uses_index_probe() {
    let mut db = Database::new("t".to_string());
    for ddl in [
        "CREATE TABLE big (id INTEGER PRIMARY KEY, tag TEXT);",
        "CREATE INDEX big_tag_idx ON big (tag);",
    ] {
        process_command(ddl, &mut db).unwrap();
    }

    let tags = (1..=100).map(|i| if i % 3 == 0 { "hot" } else { "cold" });
    for tag in tags {
        process_command(&format!("INSERT INTO big (tag) VALUES ('{tag}');"), &mut db).unwrap();
    }

    let response =
        process_command("SELECT id FROM big WHERE tag = 'hot';", &mut db).expect("select");
    assert!(
        response.contains("33 rows returned"),
        "response was {response:?}"
    );
}
/// A parenthesised equality predicate still matches exactly one row.
/// NOTE(review): only the row count is asserted, not the access path.
#[test]
fn where_eq_on_indexed_column_inside_parens_uses_index_probe() {
    let sql = "SELECT name FROM users WHERE (name = 'bob');";
    let mut db = seed_users_table();
    let summary = process_command(sql, &mut db).expect("select");
    assert!(summary.contains("1 row returned"));
}
/// Equality with the literal on the left-hand side still matches one row.
/// NOTE(review): only the row count is asserted, not the access path.
#[test]
fn where_eq_literal_first_side_uses_index_probe() {
    let sql = "SELECT name FROM users WHERE 'bob' = name;";
    let mut db = seed_users_table();
    let summary = process_command(sql, &mut db).expect("select");
    assert!(summary.contains("1 row returned"));
}
/// A range predicate (`age > 28`) returns the two matching rows.
/// NOTE(review): only the row count is asserted, not the access path.
#[test]
fn non_equality_where_still_falls_back_to_full_scan() {
    let sql = "SELECT name FROM users WHERE age > 28;";
    let mut db = seed_users_table();
    let summary = process_command(sql, &mut db).expect("select");
    assert!(summary.contains("2 rows returned"));
}
/// A row inserted between BEGIN and ROLLBACK is visible mid-transaction and
/// gone afterwards.
#[test]
fn rollback_restores_pre_begin_in_memory_state() {
    // Number of rows currently stored in `users`.
    fn user_count(db: &Database) -> usize {
        db.get_table("users".to_string()).unwrap().rowids().len()
    }

    let mut db = seed_users_table();
    assert_eq!(user_count(&db), 3);

    process_command("BEGIN;", &mut db).expect("BEGIN");
    assert!(db.in_transaction());
    process_command("INSERT INTO users (name, age) VALUES ('dan', 50);", &mut db)
        .expect("INSERT inside txn");
    assert_eq!(user_count(&db), 4);

    process_command("ROLLBACK;", &mut db).expect("ROLLBACK");
    assert!(!db.in_transaction());
    assert_eq!(
        user_count(&db),
        3,
        "ROLLBACK should have restored the pre-BEGIN state"
    );
}
/// COMMIT keeps the row inserted inside the transaction and closes it.
#[test]
fn commit_keeps_mutations_and_clears_txn_flag() {
    let mut db = seed_users_table();
    let steps = [
        ("BEGIN;", "BEGIN"),
        (
            "INSERT INTO users (name, age) VALUES ('dan', 50);",
            "INSERT inside txn",
        ),
        ("COMMIT;", "COMMIT"),
    ];
    for (sql, failure_label) in steps {
        process_command(sql, &mut db).expect(failure_label);
    }
    assert!(!db.in_transaction());
    let row_count = db.get_table("users".to_string()).unwrap().rowids().len();
    assert_eq!(row_count, 4);
}
/// ROLLBACK reverts both an UPDATE and a DELETE issued in the same
/// transaction.
#[test]
fn rollback_undoes_update_and_delete_side_by_side() {
    let mut db = seed_users_table();
    for sql in [
        "BEGIN;",
        "UPDATE users SET age = 999;",
        "DELETE FROM users WHERE name = 'bob';",
    ] {
        process_command(sql, &mut db).unwrap();
    }

    // Mid-transaction: bob is gone and every remaining age is 999.
    let in_txn = db.get_table("users".to_string()).unwrap();
    assert_eq!(in_txn.rowids().len(), 2);
    in_txn.rowids().into_iter().for_each(|rowid| {
        assert_eq!(in_txn.get_value("age", rowid), Some(Value::Integer(999)));
    });

    process_command("ROLLBACK;", &mut db).unwrap();

    // After rollback: all three rows are back and none has age 999.
    let restored = db.get_table("users".to_string()).unwrap();
    assert_eq!(restored.rowids().len(), 3);
    restored.rowids().into_iter().for_each(|rowid| {
        assert_ne!(restored.get_value("age", rowid), Some(Value::Integer(999)));
    });
}
/// A second BEGIN while a transaction is open errors with "already open"
/// and leaves the original transaction open.
#[test]
fn nested_begin_is_rejected() {
    let mut db = seed_users_table();
    process_command("BEGIN;", &mut db).unwrap();

    let second_begin = process_command("BEGIN;", &mut db);
    let err = second_begin.unwrap_err();
    assert!(
        format!("{err}").contains("already open"),
        "nested BEGIN should error; got: {err}"
    );
    assert!(db.in_transaction());

    process_command("ROLLBACK;", &mut db).unwrap();
}
/// COMMIT and ROLLBACK without an open transaction both fail with a
/// "no transaction" error.
#[test]
fn orphan_commit_and_rollback_are_rejected() {
    let mut db = seed_users_table();

    let commit_attempt = process_command("COMMIT;", &mut db);
    let commit_err = commit_attempt.unwrap_err();
    assert!(format!("{commit_err}").contains("no transaction"));

    let rollback_attempt = process_command("ROLLBACK;", &mut db);
    let rollback_err = rollback_attempt.unwrap_err();
    assert!(format!("{rollback_err}").contains("no transaction"));
}
/// A failing statement inside a transaction does not close the transaction.
#[test]
fn error_inside_transaction_keeps_txn_open() {
    let mut db = seed_users_table();
    process_command("BEGIN;", &mut db).unwrap();

    let failed_insert = process_command("INSERT INTO nope (x) VALUES (1);", &mut db);
    assert!(failed_insert.is_err());
    assert!(db.in_transaction(), "txn should stay open after error");

    process_command("ROLLBACK;", &mut db).unwrap();
}
/// Creates a database file in the system temp directory containing
/// `schema`, then reopens it and returns the path together with the
/// reopened database.
///
/// The file name combines `name`, the process id and the current
/// nanosecond timestamp (0 if the clock reads before the Unix epoch).
fn seed_file_backed(name: &str, schema: &str) -> (std::path::PathBuf, Database) {
    use crate::sql::pager::{open_database, save_database};

    let pid = std::process::id();
    let nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_nanos(),
        Err(_) => 0,
    };
    let path = std::env::temp_dir().join(format!("sqlrite-txn-{name}-{pid}-{nanos}.sqlrite"));

    // Write the seed database and release it before reopening from disk.
    let mut seed = Database::new("t".to_string());
    process_command(schema, &mut seed).unwrap();
    save_database(&mut seed, &path).unwrap();
    drop(seed);

    let reopened = open_database(&path, "t".to_string()).unwrap();
    (path, reopened)
}
/// Removes `path` and its `<path>-wal` sibling, ignoring any removal error
/// (for example when a file does not exist).
fn cleanup_file(path: &std::path::Path) {
    let wal_path = {
        let mut raw = path.as_os_str().to_owned();
        raw.push("-wal");
        std::path::PathBuf::from(raw)
    };
    for target in [path, wal_path.as_path()] {
        let _ = std::fs::remove_file(target);
    }
}
/// On a file-backed database, committed rows survive a reopen and
/// rolled-back rows do not.
#[test]
fn begin_commit_rollback_round_trip_through_disk() {
    use crate::sql::pager::open_database;

    let schema = "CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT);";
    let (path, mut db) = seed_file_backed("roundtrip", schema);

    // First transaction commits 'a' and 'b'; the second inserts 'c' and
    // rolls it back.
    for sql in [
        "BEGIN;",
        "INSERT INTO notes (body) VALUES ('a');",
        "INSERT INTO notes (body) VALUES ('b');",
        "COMMIT;",
        "BEGIN;",
        "INSERT INTO notes (body) VALUES ('c');",
        "ROLLBACK;",
    ] {
        process_command(sql, &mut db).unwrap();
    }
    drop(db);

    let reopened = open_database(&path, "t".to_string()).unwrap();
    let persisted = reopened
        .get_table("notes".to_string())
        .unwrap()
        .rowids()
        .len();
    assert_eq!(persisted, 2, "committed rows should survive");
    drop(reopened);
    cleanup_file(&path);
}
/// Writes inside an open transaction leave the `-wal` file's size
/// unchanged; COMMIT then persists both rows.
#[test]
fn write_inside_transaction_does_not_autosave() {
    let (path, mut db) =
        seed_file_backed("noas", "CREATE TABLE t (id INTEGER PRIMARY KEY, x TEXT);");
    let wal_path = {
        let mut raw = path.as_os_str().to_owned();
        raw.push("-wal");
        std::path::PathBuf::from(raw)
    };
    // Current byte length of the WAL file.
    let wal_len = || std::fs::metadata(&wal_path).unwrap().len();

    let frames_before = wal_len();
    for sql in [
        "BEGIN;",
        "INSERT INTO t (x) VALUES ('a');",
        "INSERT INTO t (x) VALUES ('b');",
    ] {
        process_command(sql, &mut db).unwrap();
    }
    let frames_mid = wal_len();
    assert_eq!(
        frames_before, frames_mid,
        "WAL should not grow during an open transaction"
    );

    process_command("COMMIT;", &mut db).unwrap();
    drop(db);

    let fresh = crate::sql::pager::open_database(&path, "t".to_string()).unwrap();
    let persisted = fresh.get_table("t".to_string()).unwrap().rowids().len();
    assert_eq!(
        persisted, 2,
        "COMMIT should have persisted both inserted rows"
    );
    drop(fresh);
    cleanup_file(&path);
}
/// A table created (and populated) inside a transaction disappears on
/// ROLLBACK.
#[test]
fn rollback_undoes_create_table() {
    let mut db = seed_users_table();
    assert_eq!(db.tables.len(), 1);

    for sql in [
        "BEGIN;",
        "CREATE TABLE dropme (id INTEGER PRIMARY KEY, x TEXT);",
        "INSERT INTO dropme (x) VALUES ('stuff');",
    ] {
        process_command(sql, &mut db).unwrap();
    }
    assert_eq!(db.tables.len(), 2);

    process_command("ROLLBACK;", &mut db).unwrap();
    let tables_after = db.tables.len();
    assert_eq!(
        tables_after, 1,
        "CREATE TABLE should have been rolled back"
    );
    assert!(db.get_table("dropme".to_string()).is_err());
}
/// After rolling back an insert into a UNIQUE column, the same value can be
/// inserted again without a uniqueness error.
#[test]
fn rollback_restores_secondary_index_state() {
    let mut db = Database::new("t".to_string());
    for sql in [
        "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE);",
        "INSERT INTO users (email) VALUES ('a@x');",
        "BEGIN;",
        "INSERT INTO users (email) VALUES ('b@x');",
        "ROLLBACK;",
    ] {
        process_command(sql, &mut db).unwrap();
    }

    let reinsert = process_command("INSERT INTO users (email) VALUES ('b@x');", &mut db);
    assert!(
        reinsert.is_ok(),
        "re-insert after rollback should succeed, got {reinsert:?}"
    );
}
/// ROLLBACK resets the table's `last_rowid`, so the next insert gets
/// `pre + 1` rather than continuing from the rolled-back rows.
#[test]
fn rollback_restores_last_rowid_counter() {
    let mut db = seed_users_table();
    let pre = db.get_table("users".to_string()).unwrap().last_rowid;

    for sql in [
        "BEGIN;",
        "INSERT INTO users (name, age) VALUES ('d', 50);",
        "INSERT INTO users (name, age) VALUES ('e', 60);",
        "ROLLBACK;",
    ] {
        process_command(sql, &mut db).unwrap();
    }
    let post = db.get_table("users".to_string()).unwrap().last_rowid;
    assert_eq!(pre, post, "last_rowid must roll back with the snapshot");

    process_command("INSERT INTO users (name, age) VALUES ('d', 50);", &mut db).unwrap();
    let table = db.get_table("users".to_string()).unwrap();
    let wanted_name = Some(Value::Text("d".into()));
    let d_rowid = table
        .rowids()
        .into_iter()
        .find(|rowid| table.get_value("name", *rowid) == wanted_name)
        .expect("d row must exist");
    assert_eq!(d_rowid, pre + 1);
}
/// COMMIT on a database with no `source_path` closes the transaction and
/// keeps the inserted row.
#[test]
fn commit_on_in_memory_db_clears_txn_without_pager_call() {
    let mut db = seed_users_table();
    assert!(db.source_path.is_none());

    for sql in [
        "BEGIN;",
        "INSERT INTO users (name, age) VALUES ('z', 99);",
        "COMMIT;",
    ] {
        process_command(sql, &mut db).unwrap();
    }

    assert!(!db.in_transaction());
    let row_count = db.get_table("users".to_string()).unwrap().rowids().len();
    assert_eq!(row_count, 4);
}
/// When the save performed by COMMIT fails, the error mentions the
/// rollback, the transaction is closed, the in-flight row is gone, and the
/// database remains usable once its original path and pager are restored.
#[test]
fn failed_commit_auto_rolls_back_in_memory_state() {
    let (path, mut db) = seed_file_backed(
        "failcommit",
        "CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT);",
    );
    process_command("INSERT INTO notes (body) VALUES ('before');", &mut db).unwrap();
    process_command("BEGIN;", &mut db).unwrap();
    process_command("INSERT INTO notes (body) VALUES ('inflight');", &mut db).unwrap();
    assert_eq!(
        db.get_table("notes".to_string()).unwrap().rowids().len(),
        2,
        "inflight row visible mid-txn"
    );

    // Force the COMMIT save to fail: detach the pager and point the
    // database at the temp directory itself rather than a file.
    let orig_source = db.source_path.clone();
    let orig_pager = db.pager.take();
    db.source_path = Some(std::env::temp_dir());

    let commit_result = process_command("COMMIT;", &mut db);
    assert!(commit_result.is_err(), "commit must fail");
    let err_str = format!("{}", commit_result.unwrap_err());
    assert!(
        err_str.contains("COMMIT failed") && err_str.contains("rolled back"),
        "error must surface auto-rollback; got: {err_str}"
    );
    assert!(
        !db.in_transaction(),
        "txn must be cleared after auto-rollback"
    );
    assert_eq!(
        db.get_table("notes".to_string()).unwrap().rowids().len(),
        1,
        "inflight row must be rolled back"
    );

    // Restore the real backing file; a later write must persist normally.
    db.source_path = orig_source;
    db.pager = orig_pager;
    process_command("INSERT INTO notes (body) VALUES ('after');", &mut db).unwrap();
    drop(db);

    let reopened = crate::sql::pager::open_database(&path, "t".to_string()).unwrap();
    let notes = reopened.get_table("notes".to_string()).unwrap();
    assert_eq!(notes.rowids().len(), 2);
    drop(reopened);
    cleanup_file(&path);
}
/// BEGIN on a database opened read-only fails with a "read-only" error and
/// does not open a transaction.
#[test]
fn begin_on_read_only_is_rejected() {
    use crate::sql::pager::{open_database_read_only, save_database};
    let path = {
        let mut p = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        p.push(format!("sqlrite-txn-ro-{pid}-{nanos}.sqlrite"));
        p
    };
    {
        // Seed the file with a writable handle, released at the end of this
        // scope before the read-only open.
        let mut seed = Database::new("t".to_string());
        process_command("CREATE TABLE t (id INTEGER PRIMARY KEY);", &mut seed).unwrap();
        save_database(&mut seed, &path).unwrap();
    }
    let mut ro = open_database_read_only(&path, "t".to_string()).unwrap();
    let err = process_command("BEGIN;", &mut ro).unwrap_err();
    assert!(
        format!("{err}").contains("read-only"),
        "BEGIN on RO db should surface read-only; got: {err}"
    );
    assert!(!ro.in_transaction());

    // Release the handle before deleting its files, as the other
    // file-backed tests do, and reuse the shared cleanup helper.
    drop(ro);
    cleanup_file(&path);
}
/// Every write statement on a read-only database fails with a "read-only"
/// error and leaves the row count unchanged; SELECT still works.
#[test]
fn read_only_database_rejects_mutations_before_touching_state() {
    use crate::sql::pager::open_database_read_only;
    let mut seed = Database::new("t".to_string());
    process_command(
        "CREATE TABLE notes (id INTEGER PRIMARY KEY, body TEXT);",
        &mut seed,
    )
    .unwrap();
    process_command("INSERT INTO notes (body) VALUES ('alpha');", &mut seed).unwrap();
    let path = {
        let mut p = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        p.push(format!("sqlrite-ro-reject-{pid}-{nanos}.sqlrite"));
        p
    };
    crate::sql::pager::save_database(&mut seed, &path).unwrap();
    drop(seed);

    let mut ro = open_database_read_only(&path, "t".to_string()).unwrap();
    let notes_before = ro.get_table("notes".to_string()).unwrap().rowids().len();
    for stmt in [
        "INSERT INTO notes (body) VALUES ('beta');",
        "UPDATE notes SET body = 'x';",
        "DELETE FROM notes;",
        "CREATE TABLE more (id INTEGER PRIMARY KEY);",
        "CREATE INDEX notes_body ON notes (body);",
    ] {
        let err = process_command(stmt, &mut ro).unwrap_err();
        assert!(
            format!("{err}").contains("read-only"),
            "stmt {stmt:?} should surface a read-only error; got: {err}"
        );
    }
    let notes_after = ro.get_table("notes".to_string()).unwrap().rowids().len();
    assert_eq!(notes_before, notes_after);

    let sel = process_command("SELECT * FROM notes;", &mut ro).expect("select on RO must work");
    assert!(sel.contains("1 row returned"));

    drop(ro);
    // Shared helper removes both the database file and its `-wal` sibling.
    cleanup_file(&path);
}
/// A `VECTOR(3)` column accepts a bracketed literal and stores it as
/// `Value::Vector` with the same components.
#[test]
fn vector_create_table_and_insert_basic() {
    let mut db = Database::new("tempdb".to_string());
    let setup = [
        (
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, embedding VECTOR(3));",
            "create table with VECTOR(3)",
        ),
        (
            "INSERT INTO docs (embedding) VALUES ([0.1, 0.2, 0.3]);",
            "insert vector",
        ),
    ];
    for (sql, failure_label) in setup {
        process_command(sql, &mut db).expect(failure_label);
    }

    let sel = process_command("SELECT * FROM docs;", &mut db).expect("select");
    assert!(sel.contains("1 row returned"));

    let docs = db.get_table("docs".to_string()).expect("docs table");
    let rowids = docs.rowids();
    assert_eq!(rowids.len(), 1);
    let stored = docs.get_value("embedding", rowids[0]);
    match stored {
        Some(Value::Vector(v)) => assert_eq!(v, vec![0.1f32, 0.2, 0.3]),
        other => panic!("expected Value::Vector(...), got {other:?}"),
    }
}
/// Inserting a vector with too few or too many components into `VECTOR(3)`
/// fails with an error naming the declared and actual dimensions.
#[test]
fn vector_dim_mismatch_at_insert_is_clean_error() {
    let mut db = Database::new("tempdb".to_string());
    process_command(
        "CREATE TABLE docs (id INTEGER PRIMARY KEY, embedding VECTOR(3));",
        &mut db,
    )
    .expect("create table");

    // Too short: two components.
    let too_short = process_command("INSERT INTO docs (embedding) VALUES ([0.1, 0.2]);", &mut db);
    let err = too_short.unwrap_err();
    let msg = format!("{err}");
    let mentions_dimension = msg.to_lowercase().contains("dimension");
    assert!(
        mentions_dimension && msg.contains("declared 3") && msg.contains("got 2"),
        "expected clear dim-mismatch error, got: {msg}"
    );

    // Too long: five components.
    let too_long = process_command(
        "INSERT INTO docs (embedding) VALUES ([0.1, 0.2, 0.3, 0.4, 0.5]);",
        &mut db,
    );
    let err = too_long.unwrap_err();
    assert!(
        format!("{err}").contains("got 5"),
        "expected dim-mismatch error mentioning got 5, got: {err}"
    );
}
/// A bare `VECTOR` column type without a dimension is rejected at
/// CREATE TABLE.
#[test]
fn vector_create_table_rejects_missing_dim() {
    let sql = "CREATE TABLE docs (id INTEGER PRIMARY KEY, embedding VECTOR);";
    let mut db = Database::new("tempdb".to_string());
    let outcome = process_command(sql, &mut db);
    assert!(
        outcome.is_err(),
        "expected CREATE TABLE with bare VECTOR to fail (no dim)"
    );
}
/// `VECTOR(0)` is rejected with an error message that mentions "vector".
#[test]
fn vector_create_table_rejects_zero_dim() {
    let sql = "CREATE TABLE docs (id INTEGER PRIMARY KEY, embedding VECTOR(0));";
    let mut db = Database::new("tempdb".to_string());
    let err = process_command(sql, &mut db).unwrap_err();
    let msg = format!("{err}");
    let mentions_vector = msg.to_lowercase().contains("vector");
    assert!(
        mentions_vector,
        "expected VECTOR-related error for VECTOR(0), got: {msg}"
    );
}
/// A 384-component vector literal can be inserted into `VECTOR(384)` and
/// the row is then selectable.
#[test]
fn vector_high_dim_works() {
    let mut db = Database::new("tempdb".to_string());
    process_command(
        "CREATE TABLE embeddings (id INTEGER PRIMARY KEY, e VECTOR(384));",
        &mut db,
    )
    .expect("create table VECTOR(384)");

    // Components 0.000, 0.001, ... rendered with f32's Display formatting.
    let components: Vec<String> = (0..384)
        .map(|i| (i as f32 * 0.001).to_string())
        .collect();
    let lit = format!("[{}]", components.join(","));
    let sql = format!("INSERT INTO embeddings (e) VALUES ({lit});");
    process_command(&sql, &mut db).expect("insert 384-dim vector");

    let sel = process_command("SELECT id FROM embeddings;", &mut db).expect("select id");
    assert!(sel.contains("1 row returned"));
}
/// Three vector rows are stored independently and read back with their
/// own components.
#[test]
fn vector_multiple_rows() {
    let mut db = Database::new("tempdb".to_string());
    process_command(
        "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
        &mut db,
    )
    .expect("create");
    for i in 0..3 {
        let sql = format!("INSERT INTO docs (e) VALUES ([{i}.0, {}.0]);", i + 1);
        process_command(&sql, &mut db).expect("insert");
    }

    let sel = process_command("SELECT * FROM docs;", &mut db).expect("select");
    assert!(sel.contains("3 rows returned"));

    let docs = db.get_table("docs".to_string()).expect("docs table");
    let rowids = docs.rowids();
    assert_eq!(rowids.len(), 3);

    // Keep only vector values, then order by first component so the
    // assertions do not depend on row order.
    let mut vectors: Vec<Vec<f32>> = Vec::new();
    for rowid in rowids.iter() {
        if let Some(Value::Vector(components)) = docs.get_value("e", *rowid) {
            vectors.push(components);
        }
    }
    vectors.sort_by(|a, b| a[0].partial_cmp(&b[0]).unwrap());

    assert_eq!(vectors[0], vec![0.0f32, 1.0]);
    assert_eq!(vectors[1], vec![1.0f32, 2.0]);
    assert_eq!(vectors[2], vec![2.0f32, 3.0]);
}
/// Builds an in-memory `docs` table with a `VECTOR(2)` column `e` holding
/// five fixed vectors.
fn seed_hnsw_table() -> Database {
    let mut db = Database::new("tempdb".to_string());
    process_command(
        "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
        &mut db,
    )
    .unwrap();

    let seed_vectors = [
        "[1.0, 0.0]",
        "[2.0, 0.0]",
        "[0.0, 3.0]",
        "[1.0, 4.0]",
        "[10.0, 10.0]",
    ];
    for v in seed_vectors {
        let sql = format!("INSERT INTO docs (e) VALUES ({v});");
        process_command(&sql, &mut db).unwrap();
    }
    db
}
/// `CREATE INDEX ... USING hnsw` registers one HNSW index entry with the
/// given name and column, containing all five seeded vectors.
#[test]
fn create_index_using_hnsw_succeeds() {
    let mut db = seed_hnsw_table();
    let summary = process_command("CREATE INDEX ix_e ON docs USING hnsw (e);", &mut db).unwrap();
    assert!(summary.to_lowercase().contains("create index"));

    let docs = db.get_table("docs".to_string()).unwrap();
    assert_eq!(docs.hnsw_indexes.len(), 1);
    let entry = &docs.hnsw_indexes[0];
    assert_eq!(entry.name, "ix_e");
    assert_eq!(entry.column_name, "e");
    assert_eq!(entry.index.len(), 5);
}
/// An HNSW index on a TEXT column is rejected with an error that mentions
/// "vector".
#[test]
fn create_index_using_hnsw_rejects_non_vector_column() {
    let mut db = Database::new("tempdb".to_string());
    process_command(
        "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT);",
        &mut db,
    )
    .unwrap();

    let attempt = process_command("CREATE INDEX ix_name ON t USING hnsw (name);", &mut db);
    let err = attempt.unwrap_err();
    let msg = format!("{err}");
    let mentions_vector = msg.to_lowercase().contains("vector");
    assert!(
        mentions_vector,
        "expected error mentioning VECTOR; got: {msg}"
    );
}
/// A nearest-neighbour query (`ORDER BY vec_distance_l2 ... LIMIT 3`)
/// returns three rows once an HNSW index exists.
/// NOTE(review): only the row count is asserted, not that the index
/// served the query.
#[test]
fn knn_query_uses_hnsw_after_create_index() {
    let mut db = seed_hnsw_table();
    process_command("CREATE INDEX ix_e ON docs USING hnsw (e);", &mut db).unwrap();

    let knn_sql = "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;";
    let resp = process_command(knn_sql, &mut db).unwrap();
    assert!(resp.contains("3 rows returned"), "got: {resp}");
}
/// Rows inserted after the HNSW index was created must be added to the index (5 seeded + 2 new
/// = 7 entries), and a subsequent top-1 nearest-neighbour query must return one row.
#[test]
fn knn_query_works_after_subsequent_inserts() {
    let mut db = seed_hnsw_table();
    process_command("CREATE INDEX ix_e ON docs USING hnsw (e);", &mut db).unwrap();

    // Two extra rows added once the index already exists.
    for literal in ["[0.5, 0.0]", "[0.1, 0.1]"].iter() {
        let insert_sql = format!("INSERT INTO docs (e) VALUES ({literal});");
        process_command(&insert_sql, &mut db).unwrap();
    }

    let indexed_len = db
        .get_table("docs".to_string())
        .unwrap()
        .hnsw_indexes[0]
        .index
        .len();
    assert_eq!(
        indexed_len, 7,
        "incremental insert should grow HNSW alongside row storage"
    );

    let knn_sql = "SELECT id FROM docs ORDER BY vec_distance_l2(e, [0.0, 0.0]) ASC LIMIT 1;";
    let resp = process_command(knn_sql, &mut db).unwrap();
    assert!(resp.contains("1 row returned"), "got: {resp}");
}
/// A `DELETE` on a table carrying an HNSW index must succeed, report one affected row, and set
/// the index entry's `needs_rebuild` flag.
#[test]
fn delete_on_hnsw_indexed_table_succeeds_and_marks_dirty() {
    let mut db = seed_hnsw_table();
    process_command("CREATE INDEX ix_e ON docs USING hnsw (e);", &mut db).unwrap();

    let delete_sql = "DELETE FROM docs WHERE id = 1;";
    let resp = process_command(delete_sql, &mut db).unwrap();
    let reports_one_row = resp.contains("1 row");
    assert!(reports_one_row, "expected 1 row deleted: {resp}");

    // Look the index up by name rather than by position.
    let table = db.get_table("docs".to_string()).unwrap();
    let index_entry = table
        .hnsw_indexes
        .iter()
        .find(|candidate| candidate.name == "ix_e")
        .unwrap();
    assert!(
        index_entry.needs_rebuild,
        "DELETE should have marked HNSW index dirty for rebuild on next save"
    );
}
/// An `UPDATE` that rewrites the indexed vector column must succeed, report one affected row,
/// and set the index entry's `needs_rebuild` flag.
#[test]
fn update_on_hnsw_indexed_vector_col_succeeds_and_marks_dirty() {
    let mut db = seed_hnsw_table();
    process_command("CREATE INDEX ix_e ON docs USING hnsw (e);", &mut db).unwrap();

    let update_sql = "UPDATE docs SET e = [9.0, 9.0] WHERE id = 1;";
    let resp = process_command(update_sql, &mut db).unwrap();
    let reports_one_row = resp.contains("1 row");
    assert!(reports_one_row, "expected 1 row updated: {resp}");

    // Look the index up by name rather than by position.
    let table = db.get_table("docs".to_string()).unwrap();
    let index_entry = table
        .hnsw_indexes
        .iter()
        .find(|candidate| candidate.name == "ix_e")
        .unwrap();
    assert!(
        index_entry.needs_rebuild,
        "UPDATE on the vector column should have marked HNSW index dirty"
    );
}
/// Issuing the same `CREATE INDEX` twice must fail the second time with an error whose text
/// contains "already exists" (case-insensitive).
#[test]
fn duplicate_index_name_errors() {
    const CREATE_INDEX_SQL: &str = "CREATE INDEX ix_e ON docs USING hnsw (e);";

    let mut db = seed_hnsw_table();
    // First creation is expected to succeed.
    process_command(CREATE_INDEX_SQL, &mut db).unwrap();

    // The identical statement is then expected to be rejected.
    let msg = process_command(CREATE_INDEX_SQL, &mut db)
        .unwrap_err()
        .to_string();

    let reports_duplicate = msg.to_lowercase().contains("already exists");
    assert!(
        reports_duplicate,
        "expected duplicate-index error; got: {msg}"
    );
}
/// `CREATE INDEX IF NOT EXISTS` for an index name that already exists must succeed without
/// registering a second HNSW index.
#[test]
fn index_if_not_exists_is_idempotent() {
    let mut db = seed_hnsw_table();

    // Plain creation first, then the guarded form using the same index name.
    let statements = [
        "CREATE INDEX ix_e ON docs USING hnsw (e);",
        "CREATE INDEX IF NOT EXISTS ix_e ON docs USING hnsw (e);",
    ];
    for sql in statements.iter().copied() {
        process_command(sql, &mut db).unwrap();
    }

    let index_count = db
        .get_table("docs".to_string())
        .unwrap()
        .hnsw_indexes
        .len();
    assert_eq!(index_count, 1);
}
/// Builds a fresh database containing a `docs` table (`id INTEGER PRIMARY KEY`, `e VECTOR(2)`)
/// with three rows: `[1.0, 0.0]`, `[0.0, 1.0]` and `[1.0, 1.0]`.
///
/// Used as the fixture for the vector distance-function tests. Panics (with a label naming the
/// failing step) if any statement fails.
fn seed_vector_docs() -> Database {
    // (vector literal, panic label) pairs, inserted in this order.
    let rows = [
        ("[1.0, 0.0]", "insert 1"),
        ("[0.0, 1.0]", "insert 2"),
        ("[1.0, 1.0]", "insert 3"),
    ];

    let mut db = Database::new(String::from("tempdb"));
    let create_sql = "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));";
    process_command(create_sql, &mut db).expect("create");

    for &(literal, label) in rows.iter() {
        let insert_sql = format!("INSERT INTO docs (e) VALUES ({literal});");
        process_command(&insert_sql, &mut db).expect(label);
    }

    db
}
/// A `WHERE vec_distance_l2(e, [1.0, 0.0]) < 1.1` filter over the three seeded rows must report
/// exactly two matching rows.
#[test]
fn vec_distance_l2_in_where_filters_correctly() {
    let mut db = seed_vector_docs();

    let sql = "SELECT * FROM docs WHERE vec_distance_l2(e, [1.0, 0.0]) < 1.1;";
    let resp = process_command(sql, &mut db).expect("select");

    let reports_two = resp.contains("2 rows returned");
    assert!(reports_two, "expected 2 rows, got: {resp}");
}
/// A `WHERE vec_distance_cosine(e, [1.0, 0.0]) < 0.5` filter over the three seeded rows must
/// report exactly two matching rows.
#[test]
fn vec_distance_cosine_in_where() {
    let mut db = seed_vector_docs();

    let sql = "SELECT * FROM docs WHERE vec_distance_cosine(e, [1.0, 0.0]) < 0.5;";
    let resp = process_command(sql, &mut db).expect("select");

    let reports_two = resp.contains("2 rows returned");
    assert!(reports_two, "expected 2 rows, got: {resp}");
}
/// A `WHERE vec_distance_dot(e, [1.0, 0.0]) < 0.0` filter over the three seeded rows must report
/// exactly two matching rows.
///
/// NOTE(review): going by the test name, `vec_distance_dot` presumably yields the negated dot
/// product, which is why a `< 0.0` threshold is used — confirm against the executor.
#[test]
fn vec_distance_dot_negated() {
    let mut db = seed_vector_docs();

    let sql = "SELECT * FROM docs WHERE vec_distance_dot(e, [1.0, 0.0]) < 0.0;";
    let resp = process_command(sql, &mut db).expect("select");

    let reports_two = resp.contains("2 rows returned");
    assert!(reports_two, "expected 2 rows, got: {resp}");
}
/// Without any index, `ORDER BY vec_distance_l2(…) ASC LIMIT 2` must still report two rows.
///
/// Only the row count in the response is asserted; the ordering of the rows is not inspected.
#[test]
fn knn_via_order_by_distance_limit() {
    let mut db = seed_vector_docs();

    let sql = "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 2;";
    let resp = process_command(sql, &mut db).expect("select");

    let reports_two = resp.contains("2 rows returned");
    assert!(reports_two, "expected 2 rows, got: {resp}");
}
/// Comparing the 2-D column against a 3-D literal must fail with an error that mentions
/// "dimension" (case-insensitive) and spells out both sizes as `lhs=2` and `rhs=3`.
#[test]
fn distance_function_dim_mismatch_errors() {
    let mut db = seed_vector_docs();

    let sql = "SELECT * FROM docs WHERE vec_distance_l2(e, [1.0, 0.0, 0.0]) < 1.0;";
    let msg = process_command(sql, &mut db).unwrap_err().to_string();

    let mentions_dimension = msg.to_lowercase().contains("dimension");
    let names_both_sizes = msg.contains("lhs=2") && msg.contains("rhs=3");
    assert!(
        mentions_dimension && names_both_sizes,
        "expected dim mismatch error, got: {msg}"
    );
}
/// Calling a function that does not exist must fail with an error that names the offending
/// function (`vec_does_not_exist`).
#[test]
fn unknown_function_errors_with_name() {
    let mut db = seed_vector_docs();

    let sql = "SELECT * FROM docs WHERE vec_does_not_exist(e, [1.0, 0.0]) < 1.0;";
    let msg = process_command(sql, &mut db).unwrap_err().to_string();

    let names_function = msg.contains("vec_does_not_exist");
    assert!(
        names_function,
        "expected error mentioning function name, got: {msg}"
    );
}
}