use std::path::Path;
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
use crate::error::{Result, SQLRiteError};
use crate::sql::db::database::Database;
use crate::sql::db::table::Value;
use crate::sql::executor::execute_select_rows;
use crate::sql::pager::{AccessMode, open_database_with_mode, save_database};
use crate::sql::parser::select::SelectQuery;
use crate::sql::process_command;
pub struct Connection {
db: Database,
}
impl Connection {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref();
let db_name = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("db")
.to_string();
let db = if path.exists() {
open_database_with_mode(path, db_name, AccessMode::ReadWrite)?
} else {
let mut fresh = Database::new(db_name);
fresh.source_path = Some(path.to_path_buf());
save_database(&mut fresh, path)?;
fresh
};
Ok(Self { db })
}
pub fn open_read_only<P: AsRef<Path>>(path: P) -> Result<Self> {
let path = path.as_ref();
let db_name = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("db")
.to_string();
let db = open_database_with_mode(path, db_name, AccessMode::ReadOnly)?;
Ok(Self { db })
}
pub fn open_in_memory() -> Result<Self> {
Ok(Self {
db: Database::new("memdb".to_string()),
})
}
pub fn execute(&mut self, sql: &str) -> Result<String> {
process_command(sql, &mut self.db)
}
pub fn prepare<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
Statement::new(self, sql)
}
pub fn in_transaction(&self) -> bool {
self.db.in_transaction()
}
pub fn is_read_only(&self) -> bool {
self.db.is_read_only()
}
#[doc(hidden)]
pub fn database(&self) -> &Database {
&self.db
}
#[doc(hidden)]
pub fn database_mut(&mut self) -> &mut Database {
&mut self.db
}
}
/// Debug output summarizes connection state (transaction flag, read-only flag,
/// table count) rather than printing any table contents.
impl std::fmt::Debug for Connection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let db = &self.db;
        let mut summary = f.debug_struct("Connection");
        summary.field("in_transaction", &db.in_transaction());
        summary.field("read_only", &db.is_read_only());
        summary.field("tables", &db.tables.len());
        summary.finish()
    }
}
/// A single parsed SQL statement bound to a [`Connection`].
///
/// Query statements are converted into a `SelectQuery` when prepared; any
/// other statement keeps only its SQL text and is executed with `run()`.
pub struct Statement<'c> {
    // Exclusive borrow: the connection cannot be used elsewhere while this statement lives.
    conn: &'c mut Connection,
    // The original SQL text; `run()` hands it to `Connection::execute` verbatim.
    sql: String,
    // Parsed classification, consulted by `query()`.
    kind: StatementKind,
}
/// Classification of a prepared statement.
enum StatementKind {
    /// A `sqlparser` query statement, pre-converted into the executor's `SelectQuery`.
    Select(SelectQuery),
    /// Any non-query statement; only `run()` can execute it.
    Other,
}
/// Debug output shows the SQL text and a label for the statement kind; the
/// borrowed connection and the parsed query are intentionally omitted.
impl std::fmt::Debug for Statement<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let kind_label = match &self.kind {
            StatementKind::Other => "Other",
            StatementKind::Select(_) => "Select",
        };
        let mut out = f.debug_struct("Statement");
        out.field("sql", &self.sql);
        out.field("kind", &kind_label);
        out.finish()
    }
}
impl<'c> Statement<'c> {
    /// Parses `sql` as exactly one SQLite-dialect statement bound to `conn`.
    ///
    /// # Errors
    /// - any parse error from `sqlparser`, converted into `SQLRiteError`;
    /// - `SQLRiteError::General` when `sql` holds more than one statement or none at all;
    /// - any error from `SelectQuery::new` for query statements.
    fn new(conn: &'c mut Connection, sql: &str) -> Result<Self> {
        let mut parsed = Parser::parse_sql(&SQLiteDialect {}, sql).map_err(SQLRiteError::from)?;
        if parsed.len() > 1 {
            return Err(SQLRiteError::General(
                "prepare() accepts a single statement; found more than one".to_string(),
            ));
        }
        let Some(statement) = parsed.pop() else {
            return Err(SQLRiteError::General("no statement to prepare".to_string()));
        };
        // Queries are converted up front so `query()` can reuse the result.
        let kind = if matches!(statement, sqlparser::ast::Statement::Query(_)) {
            StatementKind::Select(SelectQuery::new(&statement)?)
        } else {
            StatementKind::Other
        };
        Ok(Self {
            conn,
            sql: sql.to_owned(),
            kind,
        })
    }

    /// Executes the statement by passing its stored SQL text to
    /// [`Connection::execute`], returning that call's status string.
    pub fn run(&mut self) -> Result<String> {
        let text = self.sql.as_str();
        self.conn.execute(text)
    }

    /// Runs a prepared SELECT against the connection's current data.
    ///
    /// The whole result is materialized before returning. The cached query is
    /// cloned for each call, so the statement can be queried repeatedly.
    ///
    /// # Errors
    /// Fails for non-query statements, and propagates executor errors.
    pub fn query(&self) -> Result<Rows> {
        let StatementKind::Select(select) = &self.kind else {
            return Err(SQLRiteError::General(
                "query() only works on SELECT statements; use run() for DDL/DML".to_string(),
            ));
        };
        let output = execute_select_rows(select.clone(), &self.conn.db)?;
        Ok(Rows {
            columns: output.columns,
            rows: output.rows.into_iter(),
        })
    }

    /// Currently always returns `None`, for query and non-query statements alike.
    ///
    /// Projected column names are available from `Rows::columns` after calling
    /// [`Statement::query`].
    pub fn column_names(&self) -> Option<Vec<String>> {
        match self.kind {
            StatementKind::Select(_) | StatementKind::Other => None,
        }
    }
}
pub struct Rows {
columns: Vec<String>,
rows: std::vec::IntoIter<Vec<Value>>,
}
impl std::fmt::Debug for Rows {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Rows")
.field("columns", &self.columns)
.field("remaining", &self.rows.len())
.finish()
}
}
impl Rows {
pub fn columns(&self) -> &[String] {
&self.columns
}
pub fn next(&mut self) -> Result<Option<Row<'_>>> {
Ok(self.rows.next().map(|values| Row {
columns: &self.columns,
values,
}))
}
pub fn collect_all(mut self) -> Result<Vec<OwnedRow>> {
let mut out = Vec::new();
while let Some(r) = self.next()? {
out.push(r.to_owned_row());
}
Ok(out)
}
}
/// One result row handed out by `Rows::next`: it owns its values and borrows
/// the column names from the `Rows` it came from.
pub struct Row<'r> {
    columns: &'r [String],
    values: Vec<Value>,
}

impl Row<'_> {
    /// Converts the value at position `idx` into `T`.
    ///
    /// # Errors
    /// Fails when `idx` is past the last value, or when `T::from_value` rejects it.
    pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
        match self.values.get(idx) {
            Some(cell) => T::from_value(cell),
            None => Err(SQLRiteError::General(format!(
                "column index {idx} out of bounds (row has {} columns)",
                self.values.len()
            ))),
        }
    }

    /// Converts the value of the first column whose name equals `name` into `T`.
    ///
    /// # Errors
    /// Fails when no column has that name, or for the same reasons as [`Row::get`].
    pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
        let Some(position) = self.columns.iter().position(|col| col == name) else {
            return Err(SQLRiteError::General(format!(
                "no column named '{name}' in row"
            )));
        };
        self.get(position)
    }

    /// Column names of this row, in projection order.
    pub fn columns(&self) -> &[String] {
        self.columns
    }

    /// Copies the column names and values into a standalone [`OwnedRow`].
    pub fn to_owned_row(&self) -> OwnedRow {
        let columns = self.columns.to_owned();
        let values = self.values.to_vec();
        OwnedRow { columns, values }
    }
}
/// A self-contained result row that owns both its column names and its values,
/// as produced by `Row::to_owned_row` and `Rows::collect_all`.
#[derive(Debug, Clone)]
pub struct OwnedRow {
    pub columns: Vec<String>,
    pub values: Vec<Value>,
}

impl OwnedRow {
    /// Converts the value at position `idx` into `T`.
    ///
    /// # Errors
    /// Fails when `idx` is past the last value, or when `T::from_value` rejects it.
    pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
        self.values
            .get(idx)
            .ok_or_else(|| {
                SQLRiteError::General(format!(
                    "column index {idx} out of bounds (row has {} columns)",
                    self.values.len()
                ))
            })
            .and_then(T::from_value)
    }

    /// Converts the value of the first column whose name equals `name` into `T`.
    ///
    /// # Errors
    /// Fails when no column has that name, or for the same reasons as [`OwnedRow::get`].
    pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
        self.columns
            .iter()
            .position(|col| col == name)
            .ok_or_else(|| SQLRiteError::General(format!("no column named '{name}' in row")))
            .and_then(|position| self.get(position))
    }
}
/// Conversion from a stored [`Value`] into a concrete Rust type; used by the
/// typed `get` / `get_by_name` accessors on rows.
pub trait FromValue: Sized {
    /// Converts `v`, returning an error when its variant does not fit `Self`.
    fn from_value(v: &Value) -> Result<Self>;
}

/// Accepts only `Value::Integer`.
impl FromValue for i64 {
    fn from_value(v: &Value) -> Result<Self> {
        if let Value::Integer(n) = v {
            return Ok(*n);
        }
        let reason = match v {
            Value::Null => "expected Integer, got NULL".to_string(),
            other => format!("cannot convert {other:?} to i64"),
        };
        Err(SQLRiteError::General(reason))
    }
}

/// Accepts `Value::Real`, and widens `Value::Integer`.
impl FromValue for f64 {
    fn from_value(v: &Value) -> Result<Self> {
        match v {
            Value::Null => Err(SQLRiteError::General("expected Real, got NULL".to_string())),
            Value::Real(x) => Ok(*x),
            // `as` widening: integers beyond 2^53 in magnitude lose precision.
            Value::Integer(i) => Ok(*i as f64),
            unsupported => Err(SQLRiteError::General(format!(
                "cannot convert {unsupported:?} to f64"
            ))),
        }
    }
}

/// Accepts only `Value::Text`; no implicit number-to-text conversion.
impl FromValue for String {
    fn from_value(v: &Value) -> Result<Self> {
        match v {
            Value::Null => Err(SQLRiteError::General("expected Text, got NULL".to_string())),
            Value::Text(text) => Ok(text.to_owned()),
            unsupported => Err(SQLRiteError::General(format!(
                "cannot convert {unsupported:?} to String"
            ))),
        }
    }
}

/// Accepts `Value::Bool`, and treats any non-zero `Value::Integer` as `true`.
impl FromValue for bool {
    fn from_value(v: &Value) -> Result<Self> {
        match v {
            Value::Null => Err(SQLRiteError::General("expected Bool, got NULL".to_string())),
            Value::Integer(i) => Ok(*i != 0),
            Value::Bool(flag) => Ok(*flag),
            unsupported => Err(SQLRiteError::General(format!(
                "cannot convert {unsupported:?} to bool"
            ))),
        }
    }
}

/// Maps `Value::Null` to `None`; every other value is delegated to `T`.
impl<T: FromValue> FromValue for Option<T> {
    fn from_value(v: &Value) -> Result<Self> {
        if matches!(v, Value::Null) {
            Ok(None)
        } else {
            T::from_value(v).map(Some)
        }
    }
}

/// Identity conversion: returns a copy of the stored value.
impl FromValue for Value {
    fn from_value(v: &Value) -> Result<Self> {
        Ok(Value::clone(v))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a temp-dir path built from the process id, the current time in
    /// nanoseconds and `name`, so concurrently running tests are unlikely to
    /// share a database file.
    fn tmp_path(name: &str) -> std::path::PathBuf {
        let mut p = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        p.push(format!("sqlrite-conn-{pid}-{nanos}-{name}.sqlrite"));
        p
    }

    /// Best-effort removal of a database file and its `-wal` sidecar. Errors
    /// (e.g. the sidecar was never created) are deliberately ignored.
    fn cleanup(path: &std::path::Path) {
        let _ = std::fs::remove_file(path);
        let mut wal = path.as_os_str().to_owned();
        wal.push("-wal");
        let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
    }

    /// Owns a temp database path and runs `cleanup` on drop, so files are
    /// removed even when an assertion panics partway through a test.
    ///
    /// Declare the guard before any connection using it: locals drop in
    /// reverse order, so connections are released before the files go away.
    struct TempDb {
        path: std::path::PathBuf,
    }

    impl TempDb {
        fn new(name: &str) -> Self {
            Self {
                path: tmp_path(name),
            }
        }

        fn path(&self) -> &std::path::Path {
            &self.path
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            cleanup(&self.path);
        }
    }

    #[test]
    fn in_memory_roundtrip() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('alice', 30);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('bob', 25);")
            .unwrap();
        let stmt = conn.prepare("SELECT id, name, age FROM users;").unwrap();
        let mut rows = stmt.query().unwrap();
        assert_eq!(rows.columns(), &["id", "name", "age"]);
        let mut collected: Vec<(i64, String, i64)> = Vec::new();
        while let Some(row) = rows.next().unwrap() {
            collected.push((
                row.get::<i64>(0).unwrap(),
                row.get::<String>(1).unwrap(),
                row.get::<i64>(2).unwrap(),
            ));
        }
        assert_eq!(collected.len(), 2);
        assert!(collected.iter().any(|(_, n, a)| n == "alice" && *a == 30));
        assert!(collected.iter().any(|(_, n, a)| n == "bob" && *a == 25));
    }

    #[test]
    fn file_backed_persists_across_connections() {
        let db = TempDb::new("persist");
        {
            let mut c1 = Connection::open(db.path()).unwrap();
            c1.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, label TEXT);")
                .unwrap();
            c1.execute("INSERT INTO items (label) VALUES ('one');")
                .unwrap();
        }
        {
            let mut c2 = Connection::open(db.path()).unwrap();
            let stmt = c2.prepare("SELECT label FROM items;").unwrap();
            let mut rows = stmt.query().unwrap();
            let first = rows.next().unwrap().expect("one row");
            assert_eq!(first.get::<String>(0).unwrap(), "one");
            assert!(rows.next().unwrap().is_none());
        }
    }

    #[test]
    fn read_only_connection_rejects_writes() {
        let db = TempDb::new("ro_reject");
        {
            let mut c = Connection::open(db.path()).unwrap();
            c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
                .unwrap();
            c.execute("INSERT INTO t (id) VALUES (1);").unwrap();
        }
        // `ro` is declared after `db`, so it is dropped (releasing whatever it
        // holds on the file) before `db` deletes the file.
        let mut ro = Connection::open_read_only(db.path()).unwrap();
        assert!(ro.is_read_only());
        let err = ro.execute("INSERT INTO t (id) VALUES (2);").unwrap_err();
        assert!(format!("{err}").contains("read-only"));
    }

    #[test]
    fn transactions_work_through_connection() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO t (x) VALUES (1);").unwrap();
        conn.execute("BEGIN;").unwrap();
        assert!(conn.in_transaction());
        conn.execute("INSERT INTO t (x) VALUES (2);").unwrap();
        conn.execute("ROLLBACK;").unwrap();
        assert!(!conn.in_transaction());
        let stmt = conn.prepare("SELECT x FROM t;").unwrap();
        let rows = stmt.query().unwrap().collect_all().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
    }

    #[test]
    fn get_by_name_works() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (42, 'hello');")
            .unwrap();
        let stmt = conn.prepare("SELECT a, b FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get_by_name::<i64>("a").unwrap(), 42);
        assert_eq!(row.get_by_name::<String>("b").unwrap(), "hello");
    }

    #[test]
    fn null_column_maps_to_none() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, note TEXT);")
            .unwrap();
        conn.execute("INSERT INTO t (id) VALUES (1);").unwrap();
        let stmt = conn.prepare("SELECT id, note FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get::<i64>(0).unwrap(), 1);
        assert_eq!(row.get::<Option<String>>(1).unwrap(), None);
    }

    #[test]
    fn prepare_rejects_multiple_statements() {
        let mut conn = Connection::open_in_memory().unwrap();
        let err = conn.prepare("SELECT 1; SELECT 2;").unwrap_err();
        assert!(format!("{err}").contains("single statement"));
    }

    #[test]
    fn prepare_rejects_empty_input() {
        let mut conn = Connection::open_in_memory().unwrap();
        assert!(conn.prepare("").is_err());
    }

    #[test]
    fn query_on_non_select_errors() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
            .unwrap();
        let stmt = conn.prepare("INSERT INTO t VALUES (1);").unwrap();
        let err = stmt.query().unwrap_err();
        assert!(format!("{err}").contains("SELECT"));
    }

    #[test]
    fn statement_run_executes_dml() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);")
            .unwrap();
        {
            let mut stmt = conn.prepare("INSERT INTO t (x) VALUES (5);").unwrap();
            stmt.run().unwrap();
        }
        let rows = conn
            .prepare("SELECT x FROM t;")
            .unwrap()
            .query()
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 5);
    }

    #[test]
    fn collect_all_returns_only_unread_rows() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);")
            .unwrap();
        for x in 1..=3 {
            conn.execute(&format!("INSERT INTO t (x) VALUES ({x});"))
                .unwrap();
        }
        let stmt = conn.prepare("SELECT x FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        assert!(rows.next().unwrap().is_some());
        let rest = rows.collect_all().unwrap();
        assert_eq!(rest.len(), 2);
        assert!(rest.iter().all(|r| r.columns == ["x"]));
    }

    #[test]
    fn index_out_of_bounds_errors_cleanly() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY);")
            .unwrap();
        conn.execute("INSERT INTO t (a) VALUES (1);").unwrap();
        let stmt = conn.prepare("SELECT a FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        let err = row.get::<i64>(99).unwrap_err();
        assert!(format!("{err}").contains("out of bounds"));
    }
}