1use std::collections::VecDeque;
54use std::path::Path;
55use std::sync::Arc;
56
57use sqlparser::ast::Statement as AstStatement;
58use sqlparser::dialect::SQLiteDialect;
59use sqlparser::parser::Parser;
60
61use crate::error::{Result, SQLRiteError};
62use crate::sql::db::database::Database;
63use crate::sql::db::table::Value;
64use crate::sql::executor::execute_select_rows;
65use crate::sql::pager::{AccessMode, open_database_with_mode, save_database};
66use crate::sql::params::{rewrite_placeholders, substitute_params};
67use crate::sql::parser::select::SelectQuery;
68use crate::sql::process_ast_with_render;
69
/// Number of compiled plans [`Connection::prepare_cached`] keeps by default.
const DEFAULT_PREP_CACHE_CAP: usize = 16;

/// A handle to a SQLRite database, either file-backed or in-memory.
///
/// SQL is run either directly through [`Connection::execute`] or through
/// prepared [`Statement`]s created by [`Connection::prepare`] /
/// [`Connection::prepare_cached`].
pub struct Connection {
    // The database every statement on this connection runs against.
    db: Database,
    // Compiled plans keyed by their exact SQL text. Ordered from least- to
    // most-recently used: cache hits are moved to the back and eviction pops
    // from the front.
    prep_cache: VecDeque<(String, Arc<CachedPlan>)>,
    // Maximum number of entries `prep_cache` may hold (0 disables caching).
    prep_cache_cap: usize,
}
103
104impl Connection {
105 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
112 let path = path.as_ref();
113 let db_name = path
114 .file_stem()
115 .and_then(|s| s.to_str())
116 .unwrap_or("db")
117 .to_string();
118 let db = if path.exists() {
119 open_database_with_mode(path, db_name, AccessMode::ReadWrite)?
120 } else {
121 let mut fresh = Database::new(db_name);
127 fresh.source_path = Some(path.to_path_buf());
128 save_database(&mut fresh, path)?;
129 fresh
130 };
131 Ok(Self::wrap(db))
132 }
133
134 pub fn open_read_only<P: AsRef<Path>>(path: P) -> Result<Self> {
140 let path = path.as_ref();
141 let db_name = path
142 .file_stem()
143 .and_then(|s| s.to_str())
144 .unwrap_or("db")
145 .to_string();
146 let db = open_database_with_mode(path, db_name, AccessMode::ReadOnly)?;
147 Ok(Self::wrap(db))
148 }
149
150 pub fn open_in_memory() -> Result<Self> {
154 Ok(Self::wrap(Database::new("memdb".to_string())))
155 }
156
157 fn wrap(db: Database) -> Self {
158 Self {
159 db,
160 prep_cache: VecDeque::new(),
161 prep_cache_cap: DEFAULT_PREP_CACHE_CAP,
162 }
163 }
164
165 pub fn execute(&mut self, sql: &str) -> Result<String> {
175 crate::sql::process_command(sql, &mut self.db)
176 }
177
178 pub fn prepare<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
188 let plan = Arc::new(CachedPlan::compile(sql)?);
189 Ok(Statement { conn: self, plan })
190 }
191
192 pub fn prepare_cached<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
201 let plan = if let Some(pos) = self.prep_cache.iter().position(|(k, _)| k == sql) {
204 let (k, v) = self.prep_cache.remove(pos).unwrap();
205 self.prep_cache.push_back((k, Arc::clone(&v)));
206 v
207 } else {
208 let plan = Arc::new(CachedPlan::compile(sql)?);
209 self.prep_cache
210 .push_back((sql.to_string(), Arc::clone(&plan)));
211 while self.prep_cache.len() > self.prep_cache_cap {
212 self.prep_cache.pop_front();
213 }
214 plan
215 };
216 Ok(Statement { conn: self, plan })
217 }
218
219 pub fn set_prepared_cache_capacity(&mut self, cap: usize) {
225 self.prep_cache_cap = cap;
226 while self.prep_cache.len() > cap {
227 self.prep_cache.pop_front();
228 }
229 }
230
231 pub fn prepared_cache_len(&self) -> usize {
235 self.prep_cache.len()
236 }
237
238 pub fn in_transaction(&self) -> bool {
241 self.db.in_transaction()
242 }
243
244 pub fn auto_vacuum_threshold(&self) -> Option<f32> {
251 self.db.auto_vacuum_threshold()
252 }
253
254 pub fn set_auto_vacuum_threshold(&mut self, threshold: Option<f32>) -> Result<()> {
265 self.db.set_auto_vacuum_threshold(threshold)
266 }
267
268 pub fn is_read_only(&self) -> bool {
271 self.db.is_read_only()
272 }
273
274 #[doc(hidden)]
278 pub fn database(&self) -> &Database {
279 &self.db
280 }
281
282 #[doc(hidden)]
283 pub fn database_mut(&mut self) -> &mut Database {
284 &mut self.db
285 }
286}
287
288impl std::fmt::Debug for Connection {
289 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
290 f.debug_struct("Connection")
291 .field("in_transaction", &self.db.in_transaction())
292 .field("read_only", &self.db.is_read_only())
293 .field("tables", &self.db.tables.len())
294 .field("prep_cache_len", &self.prep_cache.len())
295 .finish()
296 }
297}
298
299#[derive(Debug)]
304struct CachedPlan {
305 #[allow(dead_code)]
307 sql: String,
308 ast: AstStatement,
311 param_count: usize,
315 select: Option<SelectQuery>,
319}
320
321impl CachedPlan {
322 fn compile(sql: &str) -> Result<Self> {
323 let dialect = SQLiteDialect {};
324 let mut ast = Parser::parse_sql(&dialect, sql).map_err(SQLRiteError::from)?;
325 let Some(mut stmt) = ast.pop() else {
326 return Err(SQLRiteError::General("no statement to prepare".to_string()));
327 };
328 if !ast.is_empty() {
329 return Err(SQLRiteError::General(
330 "prepare() accepts a single statement; found more than one".to_string(),
331 ));
332 }
333 let param_count = rewrite_placeholders(&mut stmt);
334 let select = match &stmt {
335 AstStatement::Query(_) => Some(SelectQuery::new(&stmt)?),
336 _ => None,
337 };
338 Ok(Self {
339 sql: sql.to_string(),
340 ast: stmt,
341 param_count,
342 select,
343 })
344 }
345}
346
/// A prepared statement bound to a [`Connection`].
///
/// It holds the connection by mutable reference for the lifetime `'c`, so the
/// connection cannot be used through any other path while the statement is
/// in use.
pub struct Statement<'c> {
    // The connection this statement executes against.
    conn: &'c mut Connection,
    // The compiled plan. It is shared with the connection's cache when the
    // statement comes from `prepare_cached`.
    plan: Arc<CachedPlan>,
}
357
358impl std::fmt::Debug for Statement<'_> {
359 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
360 f.debug_struct("Statement")
361 .field("sql", &self.plan.sql)
362 .field("param_count", &self.plan.param_count)
363 .field(
364 "kind",
365 &match self.plan.select {
366 Some(_) => "Select",
367 None => "Other",
368 },
369 )
370 .finish()
371 }
372}
373
374impl<'c> Statement<'c> {
375 pub fn parameter_count(&self) -> usize {
380 self.plan.param_count
381 }
382
383 pub fn run(&mut self) -> Result<String> {
391 if self.plan.param_count > 0 {
392 return Err(SQLRiteError::General(format!(
393 "statement has {} `?` placeholder(s); call execute_with_params()",
394 self.plan.param_count
395 )));
396 }
397 let ast = self.plan.ast.clone();
398 process_ast_with_render(ast, &mut self.conn.db).map(|o| o.status)
399 }
400
401 pub fn execute_with_params(&mut self, params: &[Value]) -> Result<String> {
409 self.check_arity(params)?;
410 let mut ast = self.plan.ast.clone();
411 if !params.is_empty() {
412 substitute_params(&mut ast, params)?;
413 }
414 process_ast_with_render(ast, &mut self.conn.db).map(|o| o.status)
415 }
416
417 pub fn query(&self) -> Result<Rows> {
425 if self.plan.param_count > 0 {
426 return Err(SQLRiteError::General(format!(
427 "statement has {} `?` placeholder(s); call query_with_params()",
428 self.plan.param_count
429 )));
430 }
431 let Some(sq) = self.plan.select.as_ref() else {
432 return Err(SQLRiteError::General(
433 "query() only works on SELECT statements; use run() for DDL/DML".to_string(),
434 ));
435 };
436 let result = execute_select_rows(sq.clone(), &self.conn.db)?;
437 Ok(Rows {
438 columns: result.columns,
439 rows: result.rows.into_iter(),
440 })
441 }
442
443 pub fn query_with_params(&self, params: &[Value]) -> Result<Rows> {
452 self.check_arity(params)?;
453 if self.plan.select.is_none() {
454 return Err(SQLRiteError::General(
455 "query_with_params() only works on SELECT statements; use execute_with_params() \
456 for DDL/DML"
457 .to_string(),
458 ));
459 }
460 let mut ast = self.plan.ast.clone();
465 if !params.is_empty() {
466 substitute_params(&mut ast, params)?;
467 }
468 let sq = SelectQuery::new(&ast)?;
469 let result = execute_select_rows(sq, &self.conn.db)?;
470 Ok(Rows {
471 columns: result.columns,
472 rows: result.rows.into_iter(),
473 })
474 }
475
476 fn check_arity(&self, params: &[Value]) -> Result<()> {
477 if params.len() != self.plan.param_count {
478 return Err(SQLRiteError::General(format!(
479 "expected {} parameter{}, got {}",
480 self.plan.param_count,
481 if self.plan.param_count == 1 { "" } else { "s" },
482 params.len()
483 )));
484 }
485 Ok(())
486 }
487
488 pub fn column_names(&self) -> Option<Vec<String>> {
491 match &self.plan.select {
492 Some(_) => {
493 None
498 }
499 None => None,
500 }
501 }
502}
503
/// The result set of [`Statement::query`] / [`Statement::query_with_params`].
///
/// All rows are computed up front by `execute_select_rows`. `Rows` walks
/// the already materialised values.
pub struct Rows {
    // Column names of the result set.
    columns: Vec<String>,
    // Rows not yet yielded by `next` / `collect_all`.
    rows: std::vec::IntoIter<Vec<Value>>,
}
516
517impl std::fmt::Debug for Rows {
518 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
519 f.debug_struct("Rows")
520 .field("columns", &self.columns)
521 .field("remaining", &self.rows.len())
522 .finish()
523 }
524}
525
526impl Rows {
527 pub fn columns(&self) -> &[String] {
529 &self.columns
530 }
531
532 pub fn next(&mut self) -> Result<Option<Row<'_>>> {
537 Ok(self.rows.next().map(|values| Row {
538 columns: &self.columns,
539 values,
540 }))
541 }
542
543 pub fn collect_all(mut self) -> Result<Vec<OwnedRow>> {
547 let mut out = Vec::new();
548 while let Some(r) = self.next()? {
549 out.push(r.to_owned_row());
550 }
551 Ok(out)
552 }
553}
554
/// One row yielded by [`Rows::next`].
///
/// The values are owned by the row. The column names are borrowed from the
/// parent [`Rows`], so a `Row` cannot outlive that borrow. Use
/// [`Row::to_owned_row`] to detach it.
pub struct Row<'r> {
    // Column names, shared with the parent `Rows`.
    columns: &'r [String],
    // This row's values, index-aligned with `columns`.
    values: Vec<Value>,
}
562
563impl<'r> Row<'r> {
564 pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
567 let v = self.values.get(idx).ok_or_else(|| {
568 SQLRiteError::General(format!(
569 "column index {idx} out of bounds (row has {} columns)",
570 self.values.len()
571 ))
572 })?;
573 T::from_value(v)
574 }
575
576 pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
578 let idx = self
579 .columns
580 .iter()
581 .position(|c| c == name)
582 .ok_or_else(|| SQLRiteError::General(format!("no column named '{name}' in row")))?;
583 self.get(idx)
584 }
585
586 pub fn columns(&self) -> &[String] {
588 self.columns
589 }
590
591 pub fn to_owned_row(&self) -> OwnedRow {
594 OwnedRow {
595 columns: self.columns.to_vec(),
596 values: self.values.clone(),
597 }
598 }
599}
600
/// A row that owns both its values and its column names, detached from the
/// [`Rows`] it came from.
///
/// Produced by [`Row::to_owned_row`] and [`Rows::collect_all`].
#[derive(Debug, Clone)]
pub struct OwnedRow {
    /// Column names, index-aligned with `values`.
    pub columns: Vec<String>,
    /// This row's values.
    pub values: Vec<Value>,
}
608
609impl OwnedRow {
610 pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
611 let v = self.values.get(idx).ok_or_else(|| {
612 SQLRiteError::General(format!(
613 "column index {idx} out of bounds (row has {} columns)",
614 self.values.len()
615 ))
616 })?;
617 T::from_value(v)
618 }
619
620 pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
621 let idx = self
622 .columns
623 .iter()
624 .position(|c| c == name)
625 .ok_or_else(|| SQLRiteError::General(format!("no column named '{name}' in row")))?;
626 self.get(idx)
627 }
628}
629
/// Conversion from a stored [`Value`] into a Rust type, used by
/// [`Row::get`], [`OwnedRow::get`], and their `get_by_name` variants.
///
/// The provided scalar impls reject SQL `NULL`. Request `Option<T>` to map
/// `NULL` to `None`.
pub trait FromValue: Sized {
    /// Converts `v` into `Self`, returning an error if the value's type is
    /// incompatible.
    fn from_value(v: &Value) -> Result<Self>;
}
637
638impl FromValue for i64 {
639 fn from_value(v: &Value) -> Result<Self> {
640 match v {
641 Value::Integer(n) => Ok(*n),
642 Value::Null => Err(SQLRiteError::General(
643 "expected Integer, got NULL".to_string(),
644 )),
645 other => Err(SQLRiteError::General(format!(
646 "cannot convert {other:?} to i64"
647 ))),
648 }
649 }
650}
651
652impl FromValue for f64 {
653 fn from_value(v: &Value) -> Result<Self> {
654 match v {
655 Value::Real(f) => Ok(*f),
656 Value::Integer(n) => Ok(*n as f64),
657 Value::Null => Err(SQLRiteError::General("expected Real, got NULL".to_string())),
658 other => Err(SQLRiteError::General(format!(
659 "cannot convert {other:?} to f64"
660 ))),
661 }
662 }
663}
664
665impl FromValue for String {
666 fn from_value(v: &Value) -> Result<Self> {
667 match v {
668 Value::Text(s) => Ok(s.clone()),
669 Value::Null => Err(SQLRiteError::General("expected Text, got NULL".to_string())),
670 other => Err(SQLRiteError::General(format!(
671 "cannot convert {other:?} to String"
672 ))),
673 }
674 }
675}
676
677impl FromValue for bool {
678 fn from_value(v: &Value) -> Result<Self> {
679 match v {
680 Value::Bool(b) => Ok(*b),
681 Value::Integer(n) => Ok(*n != 0),
682 Value::Null => Err(SQLRiteError::General("expected Bool, got NULL".to_string())),
683 other => Err(SQLRiteError::General(format!(
684 "cannot convert {other:?} to bool"
685 ))),
686 }
687 }
688}
689
690impl<T: FromValue> FromValue for Option<T> {
693 fn from_value(v: &Value) -> Result<Self> {
694 match v {
695 Value::Null => Ok(None),
696 other => Ok(Some(T::from_value(other)?)),
697 }
698 }
699}
700
701impl FromValue for Value {
704 fn from_value(v: &Value) -> Result<Self> {
705 Ok(v.clone())
706 }
707}
708
// Tests for the connection-level API: opening databases, prepared
// statements (plain and cached), parameter binding, row access, and the
// `FromValue` conversions.
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a temp-file path that is unique per process, per call time
    // (nanoseconds since the Unix epoch, 0 if unavailable), and per `name`.
    fn tmp_path(name: &str) -> std::path::PathBuf {
        let mut p = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        p.push(format!("sqlrite-conn-{pid}-{nanos}-{name}.sqlrite"));
        p
    }

    // Best-effort removal of the database file and its `-wal` sibling.
    // Errors (e.g. a file that was never created) are deliberately ignored.
    fn cleanup(path: &std::path::Path) {
        let _ = std::fs::remove_file(path);
        let mut wal = path.as_os_str().to_owned();
        wal.push("-wal");
        let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
    }

    // Create, insert, and read back through a prepared SELECT on an
    // in-memory database. Row order is not asserted, only membership.
    #[test]
    fn in_memory_roundtrip() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('alice', 30);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('bob', 25);")
            .unwrap();

        let stmt = conn.prepare("SELECT id, name, age FROM users;").unwrap();
        let mut rows = stmt.query().unwrap();
        assert_eq!(rows.columns(), &["id", "name", "age"]);
        let mut collected: Vec<(i64, String, i64)> = Vec::new();
        while let Some(row) = rows.next().unwrap() {
            collected.push((
                row.get::<i64>(0).unwrap(),
                row.get::<String>(1).unwrap(),
                row.get::<i64>(2).unwrap(),
            ));
        }
        assert_eq!(collected.len(), 2);
        assert!(collected.iter().any(|(_, n, a)| n == "alice" && *a == 30));
        assert!(collected.iter().any(|(_, n, a)| n == "bob" && *a == 25));
    }

    // Data written through one connection is visible to a second connection
    // opened on the same file. The scopes ensure `c1` is dropped first.
    #[test]
    fn file_backed_persists_across_connections() {
        let path = tmp_path("persist");
        {
            let mut c1 = Connection::open(&path).unwrap();
            c1.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, label TEXT);")
                .unwrap();
            c1.execute("INSERT INTO items (label) VALUES ('one');")
                .unwrap();
        }
        {
            let mut c2 = Connection::open(&path).unwrap();
            let stmt = c2.prepare("SELECT label FROM items;").unwrap();
            let mut rows = stmt.query().unwrap();
            let first = rows.next().unwrap().expect("one row");
            assert_eq!(first.get::<String>(0).unwrap(), "one");
            assert!(rows.next().unwrap().is_none());
        }
        cleanup(&path);
    }

    // A read-only connection reports itself as such and rejects INSERT with
    // an error mentioning "read-only".
    #[test]
    fn read_only_connection_rejects_writes() {
        let path = tmp_path("ro_reject");
        {
            let mut c = Connection::open(&path).unwrap();
            c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
                .unwrap();
            c.execute("INSERT INTO t (id) VALUES (1);").unwrap();
        }
        let mut ro = Connection::open_read_only(&path).unwrap();
        assert!(ro.is_read_only());
        let err = ro.execute("INSERT INTO t (id) VALUES (2);").unwrap_err();
        assert!(format!("{err}").contains("read-only"));
        cleanup(&path);
    }

    // BEGIN / ROLLBACK through `execute` toggles `in_transaction` and
    // discards the rolled-back insert.
    #[test]
    fn transactions_work_through_connection() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO t (x) VALUES (1);").unwrap();

        conn.execute("BEGIN;").unwrap();
        assert!(conn.in_transaction());
        conn.execute("INSERT INTO t (x) VALUES (2);").unwrap();
        conn.execute("ROLLBACK;").unwrap();
        assert!(!conn.in_transaction());

        let stmt = conn.prepare("SELECT x FROM t;").unwrap();
        let rows = stmt.query().unwrap().collect_all().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
    }

    // Columns can be looked up by name as well as by position.
    #[test]
    fn get_by_name_works() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (42, 'hello');")
            .unwrap();

        let stmt = conn.prepare("SELECT a, b FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get_by_name::<i64>("a").unwrap(), 42);
        assert_eq!(row.get_by_name::<String>("b").unwrap(), "hello");
    }

    // An unset column reads back as NULL, which `Option<T>` maps to `None`.
    #[test]
    fn null_column_maps_to_none() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, note TEXT);")
            .unwrap();
        conn.execute("INSERT INTO t (id) VALUES (1);").unwrap();

        let stmt = conn.prepare("SELECT id, note FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get::<i64>(0).unwrap(), 1);
        assert_eq!(row.get::<Option<String>>(1).unwrap(), None);
    }

    // `prepare` accepts exactly one statement.
    #[test]
    fn prepare_rejects_multiple_statements() {
        let mut conn = Connection::open_in_memory().unwrap();
        let err = conn.prepare("SELECT 1; SELECT 2;").unwrap_err();
        assert!(format!("{err}").contains("single statement"));
    }

    // `query` on a non-SELECT statement errors with a message naming SELECT.
    #[test]
    fn query_on_non_select_errors() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
            .unwrap();
        let stmt = conn.prepare("INSERT INTO t VALUES (1);").unwrap();
        let err = stmt.query().unwrap_err();
        assert!(format!("{err}").contains("SELECT"));
    }

    // Default threshold is Some(0.25). The setter accepts None and in-range
    // values, and a rejected value leaves the previous threshold intact.
    #[test]
    fn auto_vacuum_threshold_default_and_setter() {
        let mut conn = Connection::open_in_memory().unwrap();
        assert_eq!(
            conn.auto_vacuum_threshold(),
            Some(0.25),
            "fresh connection should ship with the SQLite-parity default"
        );

        conn.set_auto_vacuum_threshold(None).unwrap();
        assert_eq!(conn.auto_vacuum_threshold(), None);

        conn.set_auto_vacuum_threshold(Some(0.5)).unwrap();
        assert_eq!(conn.auto_vacuum_threshold(), Some(0.5));

        let err = conn.set_auto_vacuum_threshold(Some(1.5)).unwrap_err();
        assert!(
            format!("{err}").contains("auto_vacuum_threshold"),
            "expected typed range error, got: {err}"
        );
        assert_eq!(
            conn.auto_vacuum_threshold(),
            Some(0.5),
            "rejected setter call must not mutate the threshold"
        );
    }

    // An out-of-range column index yields an error rather than a panic.
    #[test]
    fn index_out_of_bounds_errors_cleanly() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY);")
            .unwrap();
        conn.execute("INSERT INTO t (a) VALUES (1);").unwrap();
        let stmt = conn.prepare("SELECT a FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        let err = row.get::<i64>(99).unwrap_err();
        assert!(format!("{err}").contains("out of bounds"));
    }

    // `parameter_count` equals the number of `?` markers in the SQL.
    #[test]
    fn parameter_count_reflects_question_marks() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
        let stmt = conn.prepare("SELECT a, b FROM t WHERE a = ?").unwrap();
        assert_eq!(stmt.parameter_count(), 1);
        let stmt = conn
            .prepare("SELECT a, b FROM t WHERE a = ? AND b = ?")
            .unwrap();
        assert_eq!(stmt.parameter_count(), 2);
        let stmt = conn.prepare("SELECT a FROM t").unwrap();
        assert_eq!(stmt.parameter_count(), 0);
    }

    // A bound integer parameter filters the SELECT to the matching row.
    #[test]
    fn query_with_params_binds_scalars() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT);")
            .unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (1, 'alice');")
            .unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (2, 'bob');")
            .unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (3, 'carol');")
            .unwrap();

        let stmt = conn.prepare("SELECT b FROM t WHERE a = ?").unwrap();
        let rows = stmt
            .query_with_params(&[Value::Integer(2)])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<String>(0).unwrap(), "bob");
    }

    // One prepared INSERT can be executed repeatedly with different values.
    #[test]
    fn execute_with_params_binds_insert_values() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();

        let mut stmt = conn.prepare("INSERT INTO t (a, b) VALUES (?, ?)").unwrap();
        stmt.execute_with_params(&[Value::Integer(7), Value::Text("hi".into())])
            .unwrap();
        stmt.execute_with_params(&[Value::Integer(8), Value::Text("yo".into())])
            .unwrap();

        let stmt = conn.prepare("SELECT a, b FROM t").unwrap();
        let rows = stmt.query().unwrap().collect_all().unwrap();
        assert_eq!(rows.len(), 2);
        assert!(
            rows.iter()
                .any(|r| r.get::<i64>(0).unwrap() == 7 && r.get::<String>(1).unwrap() == "hi")
        );
        assert!(
            rows.iter()
                .any(|r| r.get::<i64>(0).unwrap() == 8 && r.get::<String>(1).unwrap() == "yo")
        );
    }

    // Supplying the wrong number of parameters is a clean error.
    #[test]
    fn arity_mismatch_returns_clean_error() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
        let stmt = conn
            .prepare("SELECT * FROM t WHERE a = ? AND b = ?")
            .unwrap();
        let err = stmt.query_with_params(&[Value::Integer(1)]).unwrap_err();
        assert!(format!("{err}").contains("expected 2 parameter"));
    }

    // The parameterless entry points refuse statements with placeholders and
    // point at the `_with_params` variants.
    #[test]
    fn run_and_query_reject_when_placeholders_present() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
        let mut stmt_select = conn.prepare("SELECT a FROM t WHERE a = ?").unwrap();
        let err = stmt_select.query().unwrap_err();
        assert!(format!("{err}").contains("query_with_params"));
        let err = stmt_select.run().unwrap_err();
        assert!(format!("{err}").contains("execute_with_params"));
    }

    // `a = NULL` matches no rows, so binding NULL returns nothing.
    #[test]
    fn null_param_compares_against_null() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
        conn.execute("INSERT INTO t (a) VALUES (1);").unwrap();
        let stmt = conn.prepare("SELECT a FROM t WHERE a = ?").unwrap();
        let rows = stmt
            .query_with_params(&[Value::Null])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 0);
    }

    // A bound vector parameter drives a nearest-neighbour ORDER BY without
    // an index.
    #[test]
    fn vector_param_substitutes_through_select() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(3));")
            .unwrap();
        conn.execute("INSERT INTO v (id, e) VALUES (1, [1.0, 0.0, 0.0]);")
            .unwrap();
        conn.execute("INSERT INTO v (id, e) VALUES (2, [0.0, 1.0, 0.0]);")
            .unwrap();
        conn.execute("INSERT INTO v (id, e) VALUES (3, [0.0, 0.0, 1.0]);")
            .unwrap();

        let stmt = conn
            .prepare("SELECT id FROM v ORDER BY vec_distance_l2(e, ?) ASC LIMIT 1")
            .unwrap();
        let rows = stmt
            .query_with_params(&[Value::Vector(vec![1.0, 0.0, 0.0])])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
    }

    // Repeated `prepare_cached` calls with identical SQL share one entry;
    // distinct SQL adds another.
    #[test]
    fn prepare_cached_reuses_plans() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
        for n in 1..=3 {
            conn.execute(&format!("INSERT INTO t (a) VALUES ({n});"))
                .unwrap();
        }

        let _ = conn.prepare_cached("SELECT a FROM t WHERE a = ?").unwrap();
        let _ = conn.prepare_cached("SELECT a FROM t WHERE a = ?").unwrap();
        assert_eq!(conn.prepared_cache_len(), 1);

        let _ = conn.prepare_cached("SELECT a FROM t").unwrap();
        assert_eq!(conn.prepared_cache_len(), 2);
    }

    // Inserting past the configured capacity evicts, keeping the size capped.
    #[test]
    fn prepare_cached_evicts_when_over_capacity() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER);").unwrap();
        conn.set_prepared_cache_capacity(2);
        let _ = conn.prepare_cached("SELECT a FROM t").unwrap();
        let _ = conn.prepare_cached("SELECT a FROM t WHERE a = ?").unwrap();
        assert_eq!(conn.prepared_cache_len(), 2);
        let _ = conn.prepare_cached("SELECT a FROM t WHERE a > ?").unwrap();
        assert_eq!(conn.prepared_cache_len(), 2);
    }

    // With an HNSW index present, re-executing the same prepared statement
    // with different bound vectors returns the correct nearest row each time.
    #[test]
    fn vector_bind_through_hnsw_optimizer() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE v (id INTEGER PRIMARY KEY, e VECTOR(4));")
            .unwrap();
        let corpus: [(i64, [f32; 4]); 5] = [
            (1, [1.0, 0.0, 0.0, 0.0]),
            (2, [0.0, 1.0, 0.0, 0.0]),
            (3, [0.0, 0.0, 1.0, 0.0]),
            (4, [0.0, 0.0, 0.0, 1.0]),
            (5, [0.5, 0.5, 0.5, 0.5]),
        ];
        for (id, vec) in corpus {
            conn.execute(&format!(
                "INSERT INTO v (id, e) VALUES ({id}, [{}, {}, {}, {}]);",
                vec[0], vec[1], vec[2], vec[3]
            ))
            .unwrap();
        }
        conn.execute("CREATE INDEX v_hnsw ON v USING hnsw (e);")
            .unwrap();

        let stmt = conn
            .prepare("SELECT id FROM v ORDER BY vec_distance_l2(e, ?) ASC LIMIT 1")
            .unwrap();
        let rows = stmt
            .query_with_params(&[Value::Vector(vec![0.0, 0.0, 1.0, 0.0])])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 3);

        let rows = stmt
            .query_with_params(&[Value::Vector(vec![1.0, 0.0, 0.0, 0.0])])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
    }

    // Statements obtained from the cache execute just like freshly prepared
    // ones (INSERT with params, then SELECT with params).
    #[test]
    fn prepare_cached_executes_the_same_as_prepare() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY, b TEXT);")
            .unwrap();
        let mut ins = conn
            .prepare_cached("INSERT INTO t (a, b) VALUES (?, ?)")
            .unwrap();
        ins.execute_with_params(&[Value::Integer(1), Value::Text("alpha".into())])
            .unwrap();
        ins.execute_with_params(&[Value::Integer(2), Value::Text("beta".into())])
            .unwrap();

        let stmt = conn.prepare_cached("SELECT b FROM t WHERE a = ?").unwrap();
        let rows = stmt
            .query_with_params(&[Value::Integer(2)])
            .unwrap()
            .collect_all()
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<String>(0).unwrap(), "beta");
    }
}