1use std::path::Path;
38
39use sqlparser::dialect::SQLiteDialect;
40use sqlparser::parser::Parser;
41
42use crate::error::{Result, SQLRiteError};
43use crate::sql::db::database::Database;
44use crate::sql::db::table::Value;
45use crate::sql::executor::execute_select_rows;
46use crate::sql::pager::{AccessMode, open_database_with_mode, save_database};
47use crate::sql::parser::select::SelectQuery;
48use crate::sql::process_command;
49
50pub struct Connection {
70 db: Database,
71}
72
73impl Connection {
74 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
81 let path = path.as_ref();
82 let db_name = path
83 .file_stem()
84 .and_then(|s| s.to_str())
85 .unwrap_or("db")
86 .to_string();
87 let db = if path.exists() {
88 open_database_with_mode(path, db_name, AccessMode::ReadWrite)?
89 } else {
90 let mut fresh = Database::new(db_name);
96 fresh.source_path = Some(path.to_path_buf());
97 save_database(&mut fresh, path)?;
98 fresh
99 };
100 Ok(Self { db })
101 }
102
103 pub fn open_read_only<P: AsRef<Path>>(path: P) -> Result<Self> {
109 let path = path.as_ref();
110 let db_name = path
111 .file_stem()
112 .and_then(|s| s.to_str())
113 .unwrap_or("db")
114 .to_string();
115 let db = open_database_with_mode(path, db_name, AccessMode::ReadOnly)?;
116 Ok(Self { db })
117 }
118
119 pub fn open_in_memory() -> Result<Self> {
123 Ok(Self {
124 db: Database::new("memdb".to_string()),
125 })
126 }
127
128 pub fn execute(&mut self, sql: &str) -> Result<String> {
138 process_command(sql, &mut self.db)
139 }
140
141 pub fn prepare<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
147 Statement::new(self, sql)
148 }
149
150 pub fn in_transaction(&self) -> bool {
153 self.db.in_transaction()
154 }
155
156 pub fn auto_vacuum_threshold(&self) -> Option<f32> {
163 self.db.auto_vacuum_threshold()
164 }
165
166 pub fn set_auto_vacuum_threshold(&mut self, threshold: Option<f32>) -> Result<()> {
177 self.db.set_auto_vacuum_threshold(threshold)
178 }
179
180 pub fn is_read_only(&self) -> bool {
183 self.db.is_read_only()
184 }
185
186 #[doc(hidden)]
190 pub fn database(&self) -> &Database {
191 &self.db
192 }
193
194 #[doc(hidden)]
195 pub fn database_mut(&mut self) -> &mut Database {
196 &mut self.db
197 }
198}
199
200impl std::fmt::Debug for Connection {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("Connection")
203 .field("in_transaction", &self.db.in_transaction())
204 .field("read_only", &self.db.is_read_only())
205 .field("tables", &self.db.tables.len())
206 .finish()
207 }
208}
209
210pub struct Statement<'c> {
214 conn: &'c mut Connection,
215 sql: String,
216 kind: StatementKind,
217}
218
219enum StatementKind {
220 Select(SelectQuery),
221 Other,
222}
223
224impl std::fmt::Debug for Statement<'_> {
225 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
226 f.debug_struct("Statement")
227 .field("sql", &self.sql)
228 .field(
229 "kind",
230 &match self.kind {
231 StatementKind::Select(_) => "Select",
232 StatementKind::Other => "Other",
233 },
234 )
235 .finish()
236 }
237}
238
239impl<'c> Statement<'c> {
240 fn new(conn: &'c mut Connection, sql: &str) -> Result<Self> {
241 let dialect = SQLiteDialect {};
243 let mut ast = Parser::parse_sql(&dialect, sql).map_err(SQLRiteError::from)?;
244 let Some(stmt) = ast.pop() else {
245 return Err(SQLRiteError::General("no statement to prepare".to_string()));
246 };
247 if !ast.is_empty() {
248 return Err(SQLRiteError::General(
249 "prepare() accepts a single statement; found more than one".to_string(),
250 ));
251 }
252 let kind = match &stmt {
253 sqlparser::ast::Statement::Query(_) => StatementKind::Select(SelectQuery::new(&stmt)?),
254 _ => StatementKind::Other,
255 };
256 Ok(Self {
257 conn,
258 sql: sql.to_string(),
259 kind,
260 })
261 }
262
263 pub fn run(&mut self) -> Result<String> {
268 self.conn.execute(&self.sql)
269 }
270
271 pub fn query(&self) -> Result<Rows> {
274 match &self.kind {
275 StatementKind::Select(sq) => {
276 let result = execute_select_rows(sq.clone(), &self.conn.db)?;
277 Ok(Rows {
278 columns: result.columns,
279 rows: result.rows.into_iter(),
280 })
281 }
282 StatementKind::Other => Err(SQLRiteError::General(
283 "query() only works on SELECT statements; use run() for DDL/DML".to_string(),
284 )),
285 }
286 }
287
288 pub fn column_names(&self) -> Option<Vec<String>> {
291 match &self.kind {
292 StatementKind::Select(_) => {
293 None
298 }
299 StatementKind::Other => None,
300 }
301 }
302}
303
304pub struct Rows {
313 columns: Vec<String>,
314 rows: std::vec::IntoIter<Vec<Value>>,
315}
316
317impl std::fmt::Debug for Rows {
318 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
319 f.debug_struct("Rows")
320 .field("columns", &self.columns)
321 .field("remaining", &self.rows.len())
322 .finish()
323 }
324}
325
326impl Rows {
327 pub fn columns(&self) -> &[String] {
329 &self.columns
330 }
331
332 pub fn next(&mut self) -> Result<Option<Row<'_>>> {
337 Ok(self.rows.next().map(|values| Row {
338 columns: &self.columns,
339 values,
340 }))
341 }
342
343 pub fn collect_all(mut self) -> Result<Vec<OwnedRow>> {
347 let mut out = Vec::new();
348 while let Some(r) = self.next()? {
349 out.push(r.to_owned_row());
350 }
351 Ok(out)
352 }
353}
354
355pub struct Row<'r> {
359 columns: &'r [String],
360 values: Vec<Value>,
361}
362
363impl<'r> Row<'r> {
364 pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
367 let v = self.values.get(idx).ok_or_else(|| {
368 SQLRiteError::General(format!(
369 "column index {idx} out of bounds (row has {} columns)",
370 self.values.len()
371 ))
372 })?;
373 T::from_value(v)
374 }
375
376 pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
378 let idx = self
379 .columns
380 .iter()
381 .position(|c| c == name)
382 .ok_or_else(|| SQLRiteError::General(format!("no column named '{name}' in row")))?;
383 self.get(idx)
384 }
385
386 pub fn columns(&self) -> &[String] {
388 self.columns
389 }
390
391 pub fn to_owned_row(&self) -> OwnedRow {
394 OwnedRow {
395 columns: self.columns.to_vec(),
396 values: self.values.clone(),
397 }
398 }
399}
400
401#[derive(Debug, Clone)]
404pub struct OwnedRow {
405 pub columns: Vec<String>,
406 pub values: Vec<Value>,
407}
408
409impl OwnedRow {
410 pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
411 let v = self.values.get(idx).ok_or_else(|| {
412 SQLRiteError::General(format!(
413 "column index {idx} out of bounds (row has {} columns)",
414 self.values.len()
415 ))
416 })?;
417 T::from_value(v)
418 }
419
420 pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
421 let idx = self
422 .columns
423 .iter()
424 .position(|c| c == name)
425 .ok_or_else(|| SQLRiteError::General(format!("no column named '{name}' in row")))?;
426 self.get(idx)
427 }
428}
429
430pub trait FromValue: Sized {
435 fn from_value(v: &Value) -> Result<Self>;
436}
437
438impl FromValue for i64 {
439 fn from_value(v: &Value) -> Result<Self> {
440 match v {
441 Value::Integer(n) => Ok(*n),
442 Value::Null => Err(SQLRiteError::General(
443 "expected Integer, got NULL".to_string(),
444 )),
445 other => Err(SQLRiteError::General(format!(
446 "cannot convert {other:?} to i64"
447 ))),
448 }
449 }
450}
451
452impl FromValue for f64 {
453 fn from_value(v: &Value) -> Result<Self> {
454 match v {
455 Value::Real(f) => Ok(*f),
456 Value::Integer(n) => Ok(*n as f64),
457 Value::Null => Err(SQLRiteError::General("expected Real, got NULL".to_string())),
458 other => Err(SQLRiteError::General(format!(
459 "cannot convert {other:?} to f64"
460 ))),
461 }
462 }
463}
464
465impl FromValue for String {
466 fn from_value(v: &Value) -> Result<Self> {
467 match v {
468 Value::Text(s) => Ok(s.clone()),
469 Value::Null => Err(SQLRiteError::General("expected Text, got NULL".to_string())),
470 other => Err(SQLRiteError::General(format!(
471 "cannot convert {other:?} to String"
472 ))),
473 }
474 }
475}
476
477impl FromValue for bool {
478 fn from_value(v: &Value) -> Result<Self> {
479 match v {
480 Value::Bool(b) => Ok(*b),
481 Value::Integer(n) => Ok(*n != 0),
482 Value::Null => Err(SQLRiteError::General("expected Bool, got NULL".to_string())),
483 other => Err(SQLRiteError::General(format!(
484 "cannot convert {other:?} to bool"
485 ))),
486 }
487 }
488}
489
490impl<T: FromValue> FromValue for Option<T> {
493 fn from_value(v: &Value) -> Result<Self> {
494 match v {
495 Value::Null => Ok(None),
496 other => Ok(Some(T::from_value(other)?)),
497 }
498 }
499}
500
501impl FromValue for Value {
504 fn from_value(v: &Value) -> Result<Self> {
505 Ok(v.clone())
506 }
507}
508
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a temp-dir path that is unique per process and per call (pid + nanos).
    fn tmp_path(name: &str) -> std::path::PathBuf {
        let mut p = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        p.push(format!("sqlrite-conn-{pid}-{nanos}-{name}.sqlrite"));
        p
    }

    /// Best-effort removal of the database file and its `-wal` sidecar.
    fn cleanup(path: &std::path::Path) {
        let _ = std::fs::remove_file(path);
        let mut wal = path.as_os_str().to_owned();
        wal.push("-wal");
        let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
    }

    #[test]
    fn in_memory_roundtrip() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('alice', 30);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('bob', 25);")
            .unwrap();

        let stmt = conn.prepare("SELECT id, name, age FROM users;").unwrap();
        let mut rows = stmt.query().unwrap();
        assert_eq!(rows.columns(), &["id", "name", "age"]);
        let mut collected: Vec<(i64, String, i64)> = Vec::new();
        while let Some(row) = rows.next().unwrap() {
            collected.push((
                row.get::<i64>(0).unwrap(),
                row.get::<String>(1).unwrap(),
                row.get::<i64>(2).unwrap(),
            ));
        }
        assert_eq!(collected.len(), 2);
        assert!(collected.iter().any(|(_, n, a)| n == "alice" && *a == 30));
        assert!(collected.iter().any(|(_, n, a)| n == "bob" && *a == 25));
    }

    #[test]
    fn file_backed_persists_across_connections() {
        let path = tmp_path("persist");
        {
            let mut c1 = Connection::open(&path).unwrap();
            c1.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, label TEXT);")
                .unwrap();
            c1.execute("INSERT INTO items (label) VALUES ('one');")
                .unwrap();
        }
        {
            let mut c2 = Connection::open(&path).unwrap();
            let stmt = c2.prepare("SELECT label FROM items;").unwrap();
            let mut rows = stmt.query().unwrap();
            let first = rows.next().unwrap().expect("one row");
            assert_eq!(first.get::<String>(0).unwrap(), "one");
            assert!(rows.next().unwrap().is_none());
        }
        cleanup(&path);
    }

    #[test]
    fn read_only_connection_rejects_writes() {
        let path = tmp_path("ro_reject");
        {
            let mut c = Connection::open(&path).unwrap();
            c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
                .unwrap();
            c.execute("INSERT INTO t (id) VALUES (1);").unwrap();
        }

        // Reopen the file the writer above created, this time read-only.
        let mut ro = Connection::open_read_only(&path).unwrap();
        assert!(ro.is_read_only());
        let err = ro.execute("INSERT INTO t (id) VALUES (2);").unwrap_err();
        assert!(format!("{err}").contains("read-only"));

        // Close the handle before removing the file (some platforms refuse to
        // delete files that are still open).
        drop(ro);
        cleanup(&path);
    }

    #[test]
    fn transactions_work_through_connection() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO t (x) VALUES (1);").unwrap();

        conn.execute("BEGIN;").unwrap();
        assert!(conn.in_transaction());
        conn.execute("INSERT INTO t (x) VALUES (2);").unwrap();
        conn.execute("ROLLBACK;").unwrap();
        assert!(!conn.in_transaction());

        let stmt = conn.prepare("SELECT x FROM t;").unwrap();
        let rows = stmt.query().unwrap().collect_all().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
    }

    #[test]
    fn get_by_name_works() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (42, 'hello');")
            .unwrap();

        let stmt = conn.prepare("SELECT a, b FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get_by_name::<i64>("a").unwrap(), 42);
        assert_eq!(row.get_by_name::<String>("b").unwrap(), "hello");
    }

    #[test]
    fn owned_row_lookups_work() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (7, 'seven');")
            .unwrap();

        let stmt = conn.prepare("SELECT a, b FROM t;").unwrap();
        let rows = stmt.query().unwrap().collect_all().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get_by_name::<i64>("a").unwrap(), 7);
        assert_eq!(rows[0].get::<String>(1).unwrap(), "seven");
        assert!(rows[0].get_by_name::<i64>("missing").is_err());
        assert!(rows[0].get::<i64>(5).is_err());
    }

    #[test]
    fn null_column_maps_to_none() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, note TEXT);")
            .unwrap();
        conn.execute("INSERT INTO t (id) VALUES (1);").unwrap();

        let stmt = conn.prepare("SELECT id, note FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get::<i64>(0).unwrap(), 1);
        assert_eq!(row.get::<Option<String>>(1).unwrap(), None);
    }

    #[test]
    fn from_value_conversions() {
        assert_eq!(i64::from_value(&Value::Integer(7)).unwrap(), 7);
        assert!(i64::from_value(&Value::Null).is_err());
        assert!(i64::from_value(&Value::Text("x".to_string())).is_err());

        assert_eq!(f64::from_value(&Value::Real(1.5)).unwrap(), 1.5);
        assert_eq!(f64::from_value(&Value::Integer(2)).unwrap(), 2.0);

        assert_eq!(
            String::from_value(&Value::Text("hi".to_string())).unwrap(),
            "hi"
        );

        assert!(bool::from_value(&Value::Bool(true)).unwrap());
        assert!(bool::from_value(&Value::Integer(3)).unwrap());
        assert!(!bool::from_value(&Value::Integer(0)).unwrap());

        assert_eq!(Option::<i64>::from_value(&Value::Null).unwrap(), None);
        assert_eq!(
            Option::<i64>::from_value(&Value::Integer(9)).unwrap(),
            Some(9)
        );
    }

    #[test]
    fn prepare_rejects_multiple_statements() {
        let mut conn = Connection::open_in_memory().unwrap();
        let err = conn.prepare("SELECT 1; SELECT 2;").unwrap_err();
        assert!(format!("{err}").contains("single statement"));
    }

    #[test]
    fn query_on_non_select_errors() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
            .unwrap();
        let stmt = conn.prepare("INSERT INTO t VALUES (1);").unwrap();
        let err = stmt.query().unwrap_err();
        assert!(format!("{err}").contains("SELECT"));
    }

    #[test]
    fn auto_vacuum_threshold_default_and_setter() {
        let mut conn = Connection::open_in_memory().unwrap();
        assert_eq!(
            conn.auto_vacuum_threshold(),
            Some(0.25),
            "fresh connection should ship with the SQLite-parity default"
        );

        conn.set_auto_vacuum_threshold(None).unwrap();
        assert_eq!(conn.auto_vacuum_threshold(), None);

        conn.set_auto_vacuum_threshold(Some(0.5)).unwrap();
        assert_eq!(conn.auto_vacuum_threshold(), Some(0.5));

        let err = conn.set_auto_vacuum_threshold(Some(1.5)).unwrap_err();
        assert!(
            format!("{err}").contains("auto_vacuum_threshold"),
            "expected typed range error, got: {err}"
        );
        assert_eq!(
            conn.auto_vacuum_threshold(),
            Some(0.5),
            "rejected setter call must not mutate the threshold"
        );
    }

    #[test]
    fn index_out_of_bounds_errors_cleanly() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY);")
            .unwrap();
        conn.execute("INSERT INTO t (a) VALUES (1);").unwrap();
        let stmt = conn.prepare("SELECT a FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        let err = row.get::<i64>(99).unwrap_err();
        assert!(format!("{err}").contains("out of bounds"));
    }
}