use std::collections::VecDeque;
use std::path::Path;
use std::sync::Arc;
use sqlparser::ast::Statement as AstStatement;
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
use crate::error::{Result, SQLRiteError};
use crate::sql::db::database::Database;
use crate::sql::db::table::Value;
use crate::sql::executor::execute_select_rows;
use crate::sql::pager::{AccessMode, open_database_with_mode, save_database};
use crate::sql::params::{rewrite_placeholders, substitute_params};
use crate::sql::parser::select::SelectQuery;
use crate::sql::process_ast_with_render;
/// Initial capacity (number of compiled plans) of a connection's
/// prepared-statement cache; every new `Connection` starts with this value.
const DEFAULT_PREP_CACHE_CAP: usize = 16;
/// A handle to a single database plus a small cache of compiled
/// prepared-statement plans.
pub struct Connection {
// The database every statement on this connection runs against.
db: Database,
// Cached plans keyed by their exact SQL text. The front of the deque is
// evicted first; a cache hit moves the entry to the back.
prep_cache: VecDeque<(String, Arc<CachedPlan>)>,
// Maximum number of entries `prep_cache` may hold after an insertion.
prep_cache_cap: usize,
}
impl Connection {
    /// Opens the database file at `path` for reading and writing, creating
    /// it when no file exists there yet.
    ///
    /// The database name is the file stem of `path`, or `"db"` when the stem
    /// is missing or not valid Unicode. A newly created database records
    /// `path` as its source and is saved to `path` before being returned.
    ///
    /// # Errors
    /// Propagates any failure from loading the existing file or from saving
    /// the freshly created one.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        let db_name = Self::db_name_for(path);
        if !path.exists() {
            let mut fresh = Database::new(db_name);
            fresh.source_path = Some(path.to_path_buf());
            save_database(&mut fresh, path)?;
            return Ok(Self::wrap(fresh));
        }
        let existing = open_database_with_mode(path, db_name, AccessMode::ReadWrite)?;
        Ok(Self::wrap(existing))
    }

    /// Opens the database file at `path` with `AccessMode::ReadOnly`.
    ///
    /// Unlike [`Connection::open`], nothing is created here: the path is
    /// handed straight to the loader.
    ///
    /// # Errors
    /// Propagates any failure from loading the file.
    pub fn open_read_only<P: AsRef<Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        let db_name = Self::db_name_for(path);
        let db = open_database_with_mode(path, db_name, AccessMode::ReadOnly)?;
        Ok(Self::wrap(db))
    }

    /// Creates a connection over a brand-new database named `"memdb"` that
    /// is not loaded from or saved to a file by this constructor.
    pub fn open_in_memory() -> Result<Self> {
        let db = Database::new("memdb".to_string());
        Ok(Self::wrap(db))
    }

    /// Derives the database name from `path`: its file stem, or `"db"` when
    /// there is no stem or it cannot be represented as `&str`.
    fn db_name_for(path: &Path) -> String {
        match path.file_stem().and_then(|stem| stem.to_str()) {
            Some(stem) => stem.to_string(),
            None => "db".to_string(),
        }
    }

    /// Wraps `db` in a connection with an empty plan cache of default capacity.
    fn wrap(db: Database) -> Self {
        let prep_cache = VecDeque::new();
        Self {
            db,
            prep_cache,
            prep_cache_cap: DEFAULT_PREP_CACHE_CAP,
        }
    }

    /// Runs `sql` through the crate's command processor against this
    /// connection's database and returns its result.
    pub fn execute(&mut self, sql: &str) -> Result<String> {
        let db = &mut self.db;
        crate::sql::process_command(sql, db)
    }

    /// Compiles `sql` into a one-off statement, bypassing the plan cache.
    ///
    /// # Errors
    /// Returns whatever `CachedPlan::compile` reports (parse failure, no
    /// statement, or more than one statement).
    pub fn prepare<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
        let compiled = CachedPlan::compile(sql)?;
        Ok(Statement {
            conn: self,
            plan: Arc::new(compiled),
        })
    }

    /// Like [`Connection::prepare`], but reuses a previously compiled plan
    /// when the exact same SQL text is already cached.
    ///
    /// A hit moves the entry to the most-recently-used end of the cache. A
    /// miss compiles the SQL, stores the plan, and then evicts entries from
    /// the least-recently-used end until the cache fits its capacity.
    ///
    /// # Errors
    /// Returns the compile error on a cache miss; the cache is left
    /// untouched in that case.
    pub fn prepare_cached<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
        let hit = self
            .prep_cache
            .iter()
            .position(|(cached_sql, _)| cached_sql == sql);
        let plan = match hit {
            Some(idx) => {
                // `idx` was just produced by `position`, so the entry exists.
                let entry = self.prep_cache.remove(idx).unwrap();
                let shared = Arc::clone(&entry.1);
                self.prep_cache.push_back(entry);
                shared
            }
            None => {
                let fresh = Arc::new(CachedPlan::compile(sql)?);
                self.prep_cache
                    .push_back((sql.to_string(), Arc::clone(&fresh)));
                self.evict_over_capacity();
                fresh
            }
        };
        Ok(Statement { conn: self, plan })
    }

    /// Drops entries from the front of the cache until it is within capacity.
    fn evict_over_capacity(&mut self) {
        while self.prep_cache.len() > self.prep_cache_cap {
            self.prep_cache.pop_front();
        }
    }

    /// Sets the plan-cache capacity to `cap`, immediately evicting entries
    /// from the front of the cache if it currently holds more than `cap`.
    pub fn set_prepared_cache_capacity(&mut self, cap: usize) {
        self.prep_cache_cap = cap;
        self.evict_over_capacity();
    }

    /// Number of plans currently held in the prepared-statement cache.
    pub fn prepared_cache_len(&self) -> usize {
        self.prep_cache.len()
    }

    /// Forwards to the database's `in_transaction()`.
    pub fn in_transaction(&self) -> bool {
        self.db.in_transaction()
    }

    /// Forwards to the database's `auto_vacuum_threshold()`.
    pub fn auto_vacuum_threshold(&self) -> Option<f32> {
        self.db.auto_vacuum_threshold()
    }

    /// Forwards `threshold` to the database's `set_auto_vacuum_threshold`,
    /// returning its result unchanged.
    pub fn set_auto_vacuum_threshold(&mut self, threshold: Option<f32>) -> Result<()> {
        self.db.set_auto_vacuum_threshold(threshold)
    }

    /// Forwards to the database's `is_read_only()`.
    pub fn is_read_only(&self) -> bool {
        self.db.is_read_only()
    }

    /// Shared access to the underlying database (hidden from rustdoc).
    #[doc(hidden)]
    pub fn database(&self) -> &Database {
        &self.db
    }

    /// Exclusive access to the underlying database (hidden from rustdoc).
    #[doc(hidden)]
    pub fn database_mut(&mut self) -> &mut Database {
        &mut self.db
    }
}
/// Debug view that summarises the connection instead of dumping the whole
/// database: transaction flag, read-only flag, table count and cache size.
impl std::fmt::Debug for Connection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let db = &self.db;
        let mut out = f.debug_struct("Connection");
        out.field("in_transaction", &db.in_transaction());
        out.field("read_only", &db.is_read_only());
        out.field("tables", &db.tables.len());
        out.field("prep_cache_len", &self.prep_cache.len());
        out.finish()
    }
}
/// The result of compiling one SQL statement; shared between the plan cache
/// and the statements prepared from it.
#[derive(Debug)]
struct CachedPlan {
// Original SQL text exactly as passed to `compile`.
#[allow(dead_code)]
sql: String,
// The single parsed statement, after `rewrite_placeholders` has run on it.
ast: AstStatement,
// Value returned by `rewrite_placeholders` for this statement.
param_count: usize,
// `Some` only when the statement is an `AstStatement::Query`; built from
// the rewritten AST at compile time.
select: Option<SelectQuery>,
}
impl CachedPlan {
    /// Parses `sql` with the SQLite dialect and builds a plan from it.
    ///
    /// Exactly one statement is accepted. Its placeholders are rewritten and
    /// counted, and for query statements a `SelectQuery` is built up front.
    ///
    /// # Errors
    /// Fails when parsing fails, when the input contains no statement, when
    /// it contains more than one, or when `SelectQuery::new` rejects it.
    fn compile(sql: &str) -> Result<Self> {
        let mut parsed = Parser::parse_sql(&SQLiteDialect {}, sql).map_err(SQLRiteError::from)?;
        // Take the last statement, then insist nothing else was parsed.
        let mut stmt = match (parsed.pop(), parsed.is_empty()) {
            (None, _) => {
                return Err(SQLRiteError::General("no statement to prepare".to_string()));
            }
            (Some(_), false) => {
                return Err(SQLRiteError::General(
                    "prepare() accepts a single statement; found more than one".to_string(),
                ));
            }
            (Some(only), true) => only,
        };
        let param_count = rewrite_placeholders(&mut stmt);
        let select = if matches!(stmt, AstStatement::Query(_)) {
            Some(SelectQuery::new(&stmt)?)
        } else {
            None
        };
        Ok(Self {
            sql: sql.to_string(),
            ast: stmt,
            param_count,
            select,
        })
    }
}
/// A prepared statement tied to the connection that produced it.
pub struct Statement<'c> {
// Exclusive borrow of the owning connection for as long as the statement lives.
conn: &'c mut Connection,
// Compiled plan; reference-counted so the plan cache can hold the same plan.
plan: Arc<CachedPlan>,
}
/// Debug view showing the SQL text, the placeholder count and whether the
/// plan carries a prepared SELECT (`"Select"`) or not (`"Other"`).
impl std::fmt::Debug for Statement<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let plan = &self.plan;
        let kind = if plan.select.is_some() {
            "Select"
        } else {
            "Other"
        };
        let mut out = f.debug_struct("Statement");
        out.field("sql", &plan.sql);
        out.field("param_count", &plan.param_count);
        out.field("kind", &kind);
        out.finish()
    }
}
impl<'c> Statement<'c> {
pub fn parameter_count(&self) -> usize {
self.plan.param_count
}
pub fn run(&mut self) -> Result<String> {
if self.plan.param_count > 0 {
return Err(SQLRiteError::General(format!(
"statement has {} `?` placeholder(s); call execute_with_params()",
self.plan.param_count
)));
}
let ast = self.plan.ast.clone();
process_ast_with_render(ast, &mut self.conn.db).map(|o| o.status)
}
pub fn execute_with_params(&mut self, params: &[Value]) -> Result<String> {
self.check_arity(params)?;
let mut ast = self.plan.ast.clone();
if !params.is_empty() {
substitute_params(&mut ast, params)?;
}
process_ast_with_render(ast, &mut self.conn.db).map(|o| o.status)
}
pub fn query(&self) -> Result<Rows> {
if self.plan.param_count > 0 {
return Err(SQLRiteError::General(format!(
"statement has {} `?` placeholder(s); call query_with_params()",
self.plan.param_count
)));
}
let Some(sq) = self.plan.select.as_ref() else {
return Err(SQLRiteError::General(
"query() only works on SELECT statements; use run() for DDL/DML".to_string(),
));
};
let result = execute_select_rows(sq.clone(), &self.conn.db)?;
Ok(Rows {
columns: result.columns,
rows: result.rows.into_iter(),
})
}
pub fn query_with_params(&self, params: &[Value]) -> Result<Rows> {
self.check_arity(params)?;
if self.plan.select.is_none() {
return Err(SQLRiteError::General(
"query_with_params() only works on SELECT statements; use execute_with_params() \
for DDL/DML"
.to_string(),
));
}
let mut ast = self.plan.ast.clone();
if !params.is_empty() {
substitute_params(&mut ast, params)?;
}
let sq = SelectQuery::new(&ast)?;
let result = execute_select_rows(sq, &self.conn.db)?;
Ok(Rows {
columns: result.columns,
rows: result.rows.into_iter(),
})
}
fn check_arity(&self, params: &[Value]) -> Result<()> {
if params.len() != self.plan.param_count {
return Err(SQLRiteError::General(format!(
"expected {} parameter{}, got {}",
self.plan.param_count,
if self.plan.param_count == 1 { "" } else { "s" },
params.len()
)));
}
Ok(())
}
pub fn column_names(&self) -> Option<Vec<String>> {
match &self.plan.select {
Some(_) => {
None
}
None => None,
}
}
}
/// Cursor over the fully materialised result of a query.
pub struct Rows {
// Column names of the result set, shared by every row handed out.
columns: Vec<String>,
// Remaining rows, each a vector of values; consumed front to back.
rows: std::vec::IntoIter<Vec<Value>>,
}
/// Debug view listing the column names and how many rows are still unread.
impl std::fmt::Debug for Rows {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let remaining = self.rows.len();
        let mut out = f.debug_struct("Rows");
        out.field("columns", &self.columns);
        out.field("remaining", &remaining);
        out.finish()
    }
}
impl Rows {
pub fn columns(&self) -> &[String] {
&self.columns
}
pub fn next(&mut self) -> Result<Option<Row<'_>>> {
Ok(self.rows.next().map(|values| Row {
columns: &self.columns,
values,
}))
}
pub fn collect_all(mut self) -> Result<Vec<OwnedRow>> {
let mut out = Vec::new();
while let Some(r) = self.next()? {
out.push(r.to_owned_row());
}
Ok(out)
}
}
/// One result row that borrows its column names from the `Rows` cursor.
pub struct Row<'r> {
// Column names of the result set, borrowed for the lifetime `'r`.
columns: &'r [String],
// This row's values, indexed in parallel with `columns` by `get_by_name`.
values: Vec<Value>,
}
impl<'r> Row<'r> {
    /// Converts the value at position `idx` into `T`.
    ///
    /// # Errors
    /// Fails when `idx` is past the end of the row or when `T::from_value`
    /// rejects the stored value.
    pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
        match self.values.get(idx) {
            Some(value) => T::from_value(value),
            None => Err(SQLRiteError::General(format!(
                "column index {idx} out of bounds (row has {} columns)",
                self.values.len()
            ))),
        }
    }

    /// Converts the value of the first column whose name equals `name`
    /// (exact, case-sensitive comparison) into `T`.
    ///
    /// # Errors
    /// Fails when no column has that name, or for the same reasons as
    /// [`Row::get`].
    pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
        let found = self.columns.iter().position(|col| col == name);
        match found {
            Some(idx) => self.get(idx),
            None => Err(SQLRiteError::General(format!(
                "no column named '{name}' in row"
            ))),
        }
    }

    /// Column names of the result set this row belongs to.
    pub fn columns(&self) -> &[String] {
        self.columns
    }

    /// Copies the column names and values into a standalone `OwnedRow`.
    pub fn to_owned_row(&self) -> OwnedRow {
        let columns = self.columns.to_vec();
        let values = self.values.clone();
        OwnedRow { columns, values }
    }
}
/// A result row that owns both its column names and its values, so it does
/// not borrow from a `Rows` cursor.
#[derive(Debug, Clone)]
pub struct OwnedRow {
/// Column names of the result set this row came from.
pub columns: Vec<String>,
/// The row's values; `get_by_name` maps a position in `columns` to the same position here.
pub values: Vec<Value>,
}
impl OwnedRow {
    /// Converts the value at position `idx` into `T`.
    ///
    /// # Errors
    /// Fails when `idx` is past the end of the row or when `T::from_value`
    /// rejects the stored value.
    pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
        let Some(value) = self.values.get(idx) else {
            return Err(SQLRiteError::General(format!(
                "column index {idx} out of bounds (row has {} columns)",
                self.values.len()
            )));
        };
        T::from_value(value)
    }

    /// Converts the value of the first column whose name equals `name`
    /// (exact, case-sensitive comparison) into `T`.
    ///
    /// # Errors
    /// Fails when no column has that name, or for the same reasons as
    /// [`OwnedRow::get`].
    pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
        let Some(idx) = self.columns.iter().position(|col| col == name) else {
            return Err(SQLRiteError::General(format!(
                "no column named '{name}' in row"
            )));
        };
        self.get(idx)
    }
}
/// Conversion from a stored `Value` into a Rust type, used by the typed
/// `get` / `get_by_name` accessors on rows.
pub trait FromValue: Sized {
/// Attempts to build `Self` from a borrowed value, returning an error when
/// the value cannot be represented as `Self`.
fn from_value(v: &Value) -> Result<Self>;
}
/// Accepts only `Value::Integer`; NULL and every other variant are errors.
impl FromValue for i64 {
    fn from_value(v: &Value) -> Result<Self> {
        if let Value::Integer(n) = v {
            return Ok(*n);
        }
        let message = match v {
            Value::Null => "expected Integer, got NULL".to_string(),
            other => format!("cannot convert {other:?} to i64"),
        };
        Err(SQLRiteError::General(message))
    }
}
/// Accepts `Value::Real` as-is and widens `Value::Integer` with an `as f64`
/// cast; NULL and every other variant are errors.
impl FromValue for f64 {
    fn from_value(v: &Value) -> Result<Self> {
        let converted = match v {
            Value::Real(f) => *f,
            Value::Integer(n) => *n as f64,
            Value::Null => {
                return Err(SQLRiteError::General("expected Real, got NULL".to_string()));
            }
            other => {
                return Err(SQLRiteError::General(format!(
                    "cannot convert {other:?} to f64"
                )));
            }
        };
        Ok(converted)
    }
}
/// Accepts only `Value::Text`, returning a copy of the text; NULL and every
/// other variant are errors.
impl FromValue for String {
    fn from_value(v: &Value) -> Result<Self> {
        if let Value::Text(text) = v {
            return Ok(text.clone());
        }
        let message = match v {
            Value::Null => "expected Text, got NULL".to_string(),
            other => format!("cannot convert {other:?} to String"),
        };
        Err(SQLRiteError::General(message))
    }
}
/// Accepts `Value::Bool` as-is and treats any non-zero `Value::Integer` as
/// `true`; NULL and every other variant are errors.
impl FromValue for bool {
    fn from_value(v: &Value) -> Result<Self> {
        let flag = match v {
            Value::Bool(b) => *b,
            Value::Integer(n) => *n != 0,
            Value::Null => {
                return Err(SQLRiteError::General("expected Bool, got NULL".to_string()));
            }
            other => {
                return Err(SQLRiteError::General(format!(
                    "cannot convert {other:?} to bool"
                )));
            }
        };
        Ok(flag)
    }
}
/// Maps `Value::Null` to `None` and delegates every other value to `T`,
/// wrapping a successful conversion in `Some`.
impl<T: FromValue> FromValue for Option<T> {
    fn from_value(v: &Value) -> Result<Self> {
        if matches!(v, Value::Null) {
            Ok(None)
        } else {
            T::from_value(v).map(Some)
        }
    }
}
impl FromValue for Value {
fn from_value(v: &Value) -> Result<Self> {
Ok(v.clone())
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Builds a path inside the system temp directory for a test database.
///
/// The file name combines the process id, the current time in nanoseconds
/// since the Unix epoch (0 if the clock reads earlier than the epoch) and
/// `name`, so concurrent tests get distinct files.
fn tmp_path(name: &str) -> std::path::PathBuf {
    use std::time::{SystemTime, UNIX_EPOCH};

    let pid = std::process::id();
    let nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos(),
        Err(_) => 0,
    };
    let file_name = format!("sqlrite-conn-{pid}-{nanos}-{name}.sqlrite");
    std::env::temp_dir().join(file_name)
}
/// Best-effort removal of a test database file and its `-wal` sibling.
/// Failures (for example, a file that does not exist) are ignored.
fn cleanup(path: &std::path::Path) {
    let mut wal_name = path.as_os_str().to_os_string();
    wal_name.push("-wal");
    let targets = [path.to_path_buf(), std::path::PathBuf::from(wal_name)];
    for target in targets {
        let _ = std::fs::remove_file(target);
    }
}
/// Inserts two rows into an in-memory database and reads them back through
/// a prepared SELECT, checking column names and typed access by index.
#[test]
fn in_memory_roundtrip() {
    let mut conn = Connection::open_in_memory().unwrap();
    let setup = [
        "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);",
        "INSERT INTO users (name, age) VALUES ('alice', 30);",
        "INSERT INTO users (name, age) VALUES ('bob', 25);",
    ];
    for sql in setup {
        conn.execute(sql).unwrap();
    }

    let stmt = conn.prepare("SELECT id, name, age FROM users;").unwrap();
    let mut rows = stmt.query().unwrap();
    assert_eq!(rows.columns(), &["id", "name", "age"]);

    let mut seen: Vec<(i64, String, i64)> = Vec::new();
    while let Some(row) = rows.next().unwrap() {
        let id = row.get::<i64>(0).unwrap();
        let name = row.get::<String>(1).unwrap();
        let age = row.get::<i64>(2).unwrap();
        seen.push((id, name, age));
    }

    assert_eq!(seen.len(), 2);
    let has = |name: &str, age: i64| seen.iter().any(|(_, n, a)| n == name && *a == age);
    assert!(has("alice", 30));
    assert!(has("bob", 25));
}
/// Data written through one file-backed connection must be visible to a
/// second connection opened on the same path after the first is dropped.
#[test]
fn file_backed_persists_across_connections() {
    let path = tmp_path("persist");
    {
        // Writer: create the table and one row, then drop the connection.
        let mut writer = Connection::open(&path).unwrap();
        for sql in [
            "CREATE TABLE items (id INTEGER PRIMARY KEY, label TEXT);",
            "INSERT INTO items (label) VALUES ('one');",
        ] {
            writer.execute(sql).unwrap();
        }
    }
    {
        // Reader: a fresh connection sees exactly that one row.
        let mut reader = Connection::open(&path).unwrap();
        let stmt = reader.prepare("SELECT label FROM items;").unwrap();
        let mut rows = stmt.query().unwrap();
        let first = rows.next().unwrap().expect("one row");
        let label = first.get::<String>(0).unwrap();
        assert_eq!(label, "one");
        assert!(rows.next().unwrap().is_none());
    }
    cleanup(&path);
}
/// A connection opened read-only reports itself as such and refuses an
/// INSERT with an error mentioning "read-only".
#[test]
fn read_only_connection_rejects_writes() {
    let path = tmp_path("ro_reject");
    {
        // Seed the file through a normal read-write connection.
        let mut seed = Connection::open(&path).unwrap();
        for sql in [
            "CREATE TABLE t (id INTEGER PRIMARY KEY);",
            "INSERT INTO t (id) VALUES (1);",
        ] {
            seed.execute(sql).unwrap();
        }
    }
    let mut ro = Connection::open_read_only(&path).unwrap();
    assert!(ro.is_read_only());
    let err = ro.execute("INSERT INTO t (id) VALUES (2);").unwrap_err();
    let message = err.to_string();
    assert!(message.contains("read-only"));
    cleanup(&path);
}
/// BEGIN / ROLLBACK issued through `execute` toggle `in_transaction()` and
/// discard the row inserted inside the transaction.
#[test]
fn transactions_work_through_connection() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);",
        "INSERT INTO t (x) VALUES (1);",
        "BEGIN;",
    ] {
        conn.execute(sql).unwrap();
    }
    assert!(conn.in_transaction());

    for sql in ["INSERT INTO t (x) VALUES (2);", "ROLLBACK;"] {
        conn.execute(sql).unwrap();
    }
    assert!(!conn.in_transaction());

    // Only the row committed before BEGIN survives.
    let stmt = conn.prepare("SELECT x FROM t;").unwrap();
    let survivors = stmt.query().unwrap().collect_all().unwrap();
    assert_eq!(survivors.len(), 1);
    assert_eq!(survivors[0].get::<i64>(0).unwrap(), 1);
}
/// Typed access by column name returns the same values that were inserted.
#[test]
fn get_by_name_works() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "CREATE TABLE t (a INTEGER, b TEXT);",
        "INSERT INTO t (a, b) VALUES (42, 'hello');",
    ] {
        conn.execute(sql).unwrap();
    }
    let stmt = conn.prepare("SELECT a, b FROM t;").unwrap();
    let mut rows = stmt.query().unwrap();
    let row = rows.next().unwrap().unwrap();
    let a: i64 = row.get_by_name("a").unwrap();
    let b: String = row.get_by_name("b").unwrap();
    assert_eq!(a, 42);
    assert_eq!(b, "hello");
}
/// A column left unset by the INSERT reads back as `None` through
/// `Option<String>`.
#[test]
fn null_column_maps_to_none() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, note TEXT);",
        "INSERT INTO t (id) VALUES (1);",
    ] {
        conn.execute(sql).unwrap();
    }
    let stmt = conn.prepare("SELECT id, note FROM t;").unwrap();
    let mut rows = stmt.query().unwrap();
    let row = rows.next().unwrap().unwrap();
    let id: i64 = row.get(0).unwrap();
    let note: Option<String> = row.get(1).unwrap();
    assert_eq!(id, 1);
    assert_eq!(note, None);
}
/// Preparing SQL that contains two statements fails with the
/// "single statement" error.
#[test]
fn prepare_rejects_multiple_statements() {
    let mut conn = Connection::open_in_memory().unwrap();
    let message = conn
        .prepare("SELECT 1; SELECT 2;")
        .unwrap_err()
        .to_string();
    assert!(message.contains("single statement"));
}
/// Calling `query()` on a prepared INSERT is rejected with an error that
/// mentions SELECT.
#[test]
fn query_on_non_select_errors() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
        .unwrap();
    let insert = conn.prepare("INSERT INTO t VALUES (1);").unwrap();
    let message = insert.query().unwrap_err().to_string();
    assert!(message.contains("SELECT"));
}
/// Exercises the auto-vacuum threshold accessors: the default value,
/// disabling, setting a valid value, and rejecting an out-of-range value
/// without changing the stored threshold.
#[test]
fn auto_vacuum_threshold_default_and_setter() {
    let mut conn = Connection::open_in_memory().unwrap();
    let initial = conn.auto_vacuum_threshold();
    assert_eq!(
        initial,
        Some(0.25),
        "fresh connection should ship with the SQLite-parity default"
    );

    // Each accepted setting must be reflected by the getter.
    for wanted in [None, Some(0.5)] {
        conn.set_auto_vacuum_threshold(wanted).unwrap();
        assert_eq!(conn.auto_vacuum_threshold(), wanted);
    }

    let err = conn.set_auto_vacuum_threshold(Some(1.5)).unwrap_err();
    assert!(
        format!("{err}").contains("auto_vacuum_threshold"),
        "expected typed range error, got: {err}"
    );
    let after_rejection = conn.auto_vacuum_threshold();
    assert_eq!(
        after_rejection,
        Some(0.5),
        "rejected setter call must not mutate the threshold"
    );
}
/// Asking a row for a column index past its end yields an "out of bounds"
/// error instead of panicking.
#[test]
fn index_out_of_bounds_errors_cleanly() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "CREATE TABLE t (a INTEGER PRIMARY KEY);",
        "INSERT INTO t (a) VALUES (1);",
    ] {
        conn.execute(sql).unwrap();
    }
    let stmt = conn.prepare("SELECT a FROM t;").unwrap();
    let mut rows = stmt.query().unwrap();
    let row = rows.next().unwrap().unwrap();
    let message = row.get::<i64>(99).unwrap_err().to_string();
    assert!(message.contains("out of bounds"));
}
/// `parameter_count()` equals the number of `?` placeholders in the SQL.
#[test]
fn parameter_count_reflects_question_marks() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();

    let cases: [(&str, usize); 3] = [
        ("SELECT a, b FROM t WHERE a = ?", 1),
        ("SELECT a, b FROM t WHERE a = ? AND b = ?", 2),
        ("SELECT a FROM t", 0),
    ];
    for (sql, expected) in cases {
        let stmt = conn.prepare(sql).unwrap();
        assert_eq!(stmt.parameter_count(), expected);
    }
}
/// Binding an integer to a WHERE placeholder selects only the matching row.
#[test]
fn query_with_params_binds_scalars() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT);",
        "INSERT INTO t (a, b) VALUES (1, 'alice');",
        "INSERT INTO t (a, b) VALUES (2, 'bob');",
        "INSERT INTO t (a, b) VALUES (3, 'carol');",
    ] {
        conn.execute(sql).unwrap();
    }
    let stmt = conn.prepare("SELECT b FROM t WHERE a = ?").unwrap();
    let bound = [Value::Integer(2)];
    let matches = stmt
        .query_with_params(&bound)
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(matches.len(), 1);
    assert_eq!(matches[0].get::<String>(0).unwrap(), "bob");
}
/// A prepared INSERT can be executed repeatedly with different bound
/// values, and each execution stores its own row.
#[test]
fn execute_with_params_binds_insert_values() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();

    let mut insert = conn.prepare("INSERT INTO t (a, b) VALUES (?, ?)").unwrap();
    for (number, text) in [(7, "hi"), (8, "yo")] {
        insert
            .execute_with_params(&[Value::Integer(number), Value::Text(text.into())])
            .unwrap();
    }

    let select = conn.prepare("SELECT a, b FROM t").unwrap();
    let rows = select.query().unwrap().collect_all().unwrap();
    assert_eq!(rows.len(), 2);
    let stored = |number: i64, text: &str| {
        rows.iter().any(|r| {
            r.get::<i64>(0).unwrap() == number && r.get::<String>(1).unwrap() == text
        })
    };
    assert!(stored(7, "hi"));
    assert!(stored(8, "yo"));
}
/// Supplying one parameter to a two-placeholder statement produces the
/// "expected 2 parameter…" error.
#[test]
fn arity_mismatch_returns_clean_error() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
    let stmt = conn
        .prepare("SELECT * FROM t WHERE a = ? AND b = ?")
        .unwrap();
    let too_few = [Value::Integer(1)];
    let message = stmt.query_with_params(&too_few).unwrap_err().to_string();
    assert!(message.contains("expected 2 parameter"));
}
/// The parameterless entry points refuse a statement with placeholders and
/// point the caller at the matching `*_with_params` method.
#[test]
fn run_and_query_reject_when_placeholders_present() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
    let mut stmt_select = conn.prepare("SELECT a FROM t WHERE a = ?").unwrap();

    let query_message = stmt_select.query().unwrap_err().to_string();
    assert!(query_message.contains("query_with_params"));

    let run_message = stmt_select.run().unwrap_err().to_string();
    assert!(run_message.contains("execute_with_params"));
}
/// Binding NULL to an `a = ?` comparison matches no rows.
#[test]
fn null_param_compares_against_null() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "CREATE TABLE t (a INTEGER);",
        "INSERT INTO t (a) VALUES (1);",
    ] {
        conn.execute(sql).unwrap();
    }
    let stmt = conn.prepare("SELECT a FROM t WHERE a = ?").unwrap();
    let bound = [Value::Null];
    let matches = stmt
        .query_with_params(&bound)
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(matches.len(), 0);
}
/// A vector bound to a placeholder inside `vec_distance_l2` drives the
/// ORDER BY: the row holding the identical vector comes back first.
#[test]
fn vector_param_substitutes_through_select() {
    let mut conn = Connection::open_in_memory().unwrap();
    for sql in [
        "CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(3));",
        "INSERT INTO v (id, e) VALUES (1, [1.0, 0.0, 0.0]);",
        "INSERT INTO v (id, e) VALUES (2, [0.0, 1.0, 0.0]);",
        "INSERT INTO v (id, e) VALUES (3, [0.0, 0.0, 1.0]);",
    ] {
        conn.execute(sql).unwrap();
    }
    let stmt = conn
        .prepare("SELECT id FROM v ORDER BY vec_distance_l2(e, ?) ASC LIMIT 1")
        .unwrap();
    let probe = [Value::Vector(vec![1.0, 0.0, 0.0])];
    let nearest = stmt
        .query_with_params(&probe)
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(nearest.len(), 1);
    assert_eq!(nearest[0].get::<i64>(0).unwrap(), 1);
}
/// Preparing the same SQL twice through the cache keeps a single entry;
/// different SQL adds a second one.
#[test]
fn prepare_cached_reuses_plans() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
    for n in 1..=3 {
        let insert = format!("INSERT INTO t (a) VALUES ({n});");
        conn.execute(&insert).unwrap();
    }

    // Same text twice: one cache entry.
    for _ in 0..2 {
        let _ = conn.prepare_cached("SELECT a FROM t WHERE a = ?").unwrap();
    }
    assert_eq!(conn.prepared_cache_len(), 1);

    // New text: a second entry.
    let _ = conn.prepare_cached("SELECT a FROM t").unwrap();
    assert_eq!(conn.prepared_cache_len(), 2);
}
/// With capacity 2, caching a third distinct statement keeps the cache at
/// two entries.
#[test]
fn prepare_cached_evicts_when_over_capacity() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
    conn.set_prepared_cache_capacity(2);

    for sql in ["SELECT a FROM t", "SELECT a FROM t WHERE a = ?"] {
        let _ = conn.prepare_cached(sql).unwrap();
    }
    assert_eq!(conn.prepared_cache_len(), 2);

    let _ = conn.prepare_cached("SELECT a FROM t WHERE a > ?").unwrap();
    assert_eq!(conn.prepared_cache_len(), 2);
}
/// With an HNSW index on the vector column, a bound query vector still
/// returns the nearest row, and the same prepared statement can be re-run
/// with a different vector.
#[test]
fn vector_bind_through_hnsw_optimizer() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(4));")
        .unwrap();

    let corpus: [(i64, [f32; 4]); 5] = [
        (1, [1.0, 0.0, 0.0, 0.0]),
        (2, [0.0, 1.0, 0.0, 0.0]),
        (3, [0.0, 0.0, 1.0, 0.0]),
        (4, [0.0, 0.0, 0.0, 1.0]),
        (5, [0.5, 0.5, 0.5, 0.5]),
    ];
    for (id, [x0, x1, x2, x3]) in corpus {
        let insert = format!(
            "INSERT INTO v (id, e) VALUES ({id}, [{}, {}, {}, {}]);",
            x0, x1, x2, x3
        );
        conn.execute(&insert).unwrap();
    }
    conn.execute("CREATE INDEX v_hnsw ON v USING hnsw (e);")
        .unwrap();

    let stmt = conn
        .prepare("SELECT id FROM v ORDER BY vec_distance_l2(e, ?) ASC LIMIT 1")
        .unwrap();

    // Each probe equals one stored vector, so that row must be the nearest.
    let probes = [
        (vec![0.0, 0.0, 1.0, 0.0], 3),
        (vec![1.0, 0.0, 0.0, 0.0], 1),
    ];
    for (probe, expected_id) in probes {
        let nearest = stmt
            .query_with_params(&[Value::Vector(probe)])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(nearest.len(), 1);
        assert_eq!(nearest[0].get::<i64>(0).unwrap(), expected_id);
    }
}
/// Statements obtained from the plan cache bind and execute exactly like
/// freshly prepared ones, for both INSERT and SELECT.
#[test]
fn prepare_cached_executes_the_same_as_prepare() {
    let mut conn = Connection::open_in_memory().unwrap();
    conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT);")
        .unwrap();

    let mut insert = conn
        .prepare_cached("INSERT INTO t (a, b) VALUES (?, ?)")
        .unwrap();
    for (key, label) in [(1, "alpha"), (2, "beta")] {
        insert
            .execute_with_params(&[Value::Integer(key), Value::Text(label.into())])
            .unwrap();
    }

    let select = conn.prepare_cached("SELECT b FROM t WHERE a = ?").unwrap();
    let bound = [Value::Integer(2)];
    let matches = select
        .query_with_params(&bound)
        .unwrap()
        .collect_all()
        .unwrap();
    assert_eq!(matches.len(), 1);
    assert_eq!(matches[0].get::<String>(0).unwrap(), "beta");
}
}