1use std::path::Path;
38
39use sqlparser::dialect::SQLiteDialect;
40use sqlparser::parser::Parser;
41
42use crate::error::{Result, SQLRiteError};
43use crate::sql::db::database::Database;
44use crate::sql::db::table::Value;
45use crate::sql::executor::execute_select_rows;
46use crate::sql::pager::{AccessMode, open_database_with_mode, save_database};
47use crate::sql::parser::select::SelectQuery;
48use crate::sql::process_command;
49
/// A handle that owns one [`Database`] and exposes the SQL entry points
/// defined in this file (`execute`, `prepare`, and the state accessors).
///
/// Instances are created through `open`, `open_read_only`, or
/// `open_in_memory`.
pub struct Connection {
    /// The database this connection reads from and writes to.
    db: Database,
}
72
73impl Connection {
74 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
81 let path = path.as_ref();
82 let db_name = path
83 .file_stem()
84 .and_then(|s| s.to_str())
85 .unwrap_or("db")
86 .to_string();
87 let db = if path.exists() {
88 open_database_with_mode(path, db_name, AccessMode::ReadWrite)?
89 } else {
90 let mut fresh = Database::new(db_name);
96 fresh.source_path = Some(path.to_path_buf());
97 save_database(&mut fresh, path)?;
98 fresh
99 };
100 Ok(Self { db })
101 }
102
103 pub fn open_read_only<P: AsRef<Path>>(path: P) -> Result<Self> {
109 let path = path.as_ref();
110 let db_name = path
111 .file_stem()
112 .and_then(|s| s.to_str())
113 .unwrap_or("db")
114 .to_string();
115 let db = open_database_with_mode(path, db_name, AccessMode::ReadOnly)?;
116 Ok(Self { db })
117 }
118
119 pub fn open_in_memory() -> Result<Self> {
123 Ok(Self {
124 db: Database::new("memdb".to_string()),
125 })
126 }
127
128 pub fn execute(&mut self, sql: &str) -> Result<String> {
138 process_command(sql, &mut self.db)
139 }
140
141 pub fn prepare<'c>(&'c mut self, sql: &str) -> Result<Statement<'c>> {
147 Statement::new(self, sql)
148 }
149
150 pub fn in_transaction(&self) -> bool {
153 self.db.in_transaction()
154 }
155
156 pub fn is_read_only(&self) -> bool {
159 self.db.is_read_only()
160 }
161
162 #[doc(hidden)]
166 pub fn database(&self) -> &Database {
167 &self.db
168 }
169
170 #[doc(hidden)]
171 pub fn database_mut(&mut self) -> &mut Database {
172 &mut self.db
173 }
174}
175
176impl std::fmt::Debug for Connection {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 f.debug_struct("Connection")
179 .field("in_transaction", &self.db.in_transaction())
180 .field("read_only", &self.db.is_read_only())
181 .field("tables", &self.db.tables.len())
182 .finish()
183 }
184}
185
/// A single parsed SQL statement tied to the [`Connection`] it was prepared
/// on. The connection stays mutably borrowed for the lifetime `'c`.
pub struct Statement<'c> {
    /// Connection the statement runs against.
    conn: &'c mut Connection,
    /// The SQL text exactly as it was handed to the constructor.
    sql: String,
    /// Whether this is a query (with its parsed form) or any other statement.
    kind: StatementKind,
}
194
/// Classification of a prepared statement.
enum StatementKind {
    /// A query statement, carrying its parsed [`SelectQuery`].
    Select(SelectQuery),
    /// Any statement that is not a query.
    Other,
}
199
200impl std::fmt::Debug for Statement<'_> {
201 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
202 f.debug_struct("Statement")
203 .field("sql", &self.sql)
204 .field(
205 "kind",
206 &match self.kind {
207 StatementKind::Select(_) => "Select",
208 StatementKind::Other => "Other",
209 },
210 )
211 .finish()
212 }
213}
214
215impl<'c> Statement<'c> {
216 fn new(conn: &'c mut Connection, sql: &str) -> Result<Self> {
217 let dialect = SQLiteDialect {};
219 let mut ast = Parser::parse_sql(&dialect, sql).map_err(SQLRiteError::from)?;
220 let Some(stmt) = ast.pop() else {
221 return Err(SQLRiteError::General("no statement to prepare".to_string()));
222 };
223 if !ast.is_empty() {
224 return Err(SQLRiteError::General(
225 "prepare() accepts a single statement; found more than one".to_string(),
226 ));
227 }
228 let kind = match &stmt {
229 sqlparser::ast::Statement::Query(_) => StatementKind::Select(SelectQuery::new(&stmt)?),
230 _ => StatementKind::Other,
231 };
232 Ok(Self {
233 conn,
234 sql: sql.to_string(),
235 kind,
236 })
237 }
238
239 pub fn run(&mut self) -> Result<String> {
244 self.conn.execute(&self.sql)
245 }
246
247 pub fn query(&self) -> Result<Rows> {
250 match &self.kind {
251 StatementKind::Select(sq) => {
252 let result = execute_select_rows(sq.clone(), &self.conn.db)?;
253 Ok(Rows {
254 columns: result.columns,
255 rows: result.rows.into_iter(),
256 })
257 }
258 StatementKind::Other => Err(SQLRiteError::General(
259 "query() only works on SELECT statements; use run() for DDL/DML".to_string(),
260 )),
261 }
262 }
263
264 pub fn column_names(&self) -> Option<Vec<String>> {
267 match &self.kind {
268 StatementKind::Select(_) => {
269 None
274 }
275 StatementKind::Other => None,
276 }
277 }
278}
279
/// The result set of a query: column names plus the rows that have not been
/// handed out yet.
pub struct Rows {
    /// Names of the result columns.
    columns: Vec<String>,
    /// Owning iterator over the remaining rows; each row is a vector of values.
    rows: std::vec::IntoIter<Vec<Value>>,
}
292
293impl std::fmt::Debug for Rows {
294 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
295 f.debug_struct("Rows")
296 .field("columns", &self.columns)
297 .field("remaining", &self.rows.len())
298 .finish()
299 }
300}
301
302impl Rows {
303 pub fn columns(&self) -> &[String] {
305 &self.columns
306 }
307
308 pub fn next(&mut self) -> Result<Option<Row<'_>>> {
313 Ok(self.rows.next().map(|values| Row {
314 columns: &self.columns,
315 values,
316 }))
317 }
318
319 pub fn collect_all(mut self) -> Result<Vec<OwnedRow>> {
323 let mut out = Vec::new();
324 while let Some(r) = self.next()? {
325 out.push(r.to_owned_row());
326 }
327 Ok(out)
328 }
329}
330
/// One row of a result set: owned values paired with borrowed column names.
pub struct Row<'r> {
    /// Column names, borrowed for `'r`.
    columns: &'r [String],
    /// The row's values, looked up by position.
    values: Vec<Value>,
}
338
339impl<'r> Row<'r> {
340 pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
343 let v = self.values.get(idx).ok_or_else(|| {
344 SQLRiteError::General(format!(
345 "column index {idx} out of bounds (row has {} columns)",
346 self.values.len()
347 ))
348 })?;
349 T::from_value(v)
350 }
351
352 pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
354 let idx = self
355 .columns
356 .iter()
357 .position(|c| c == name)
358 .ok_or_else(|| SQLRiteError::General(format!("no column named '{name}' in row")))?;
359 self.get(idx)
360 }
361
362 pub fn columns(&self) -> &[String] {
364 self.columns
365 }
366
367 pub fn to_owned_row(&self) -> OwnedRow {
370 OwnedRow {
371 columns: self.columns.to_vec(),
372 values: self.values.clone(),
373 }
374 }
375}
376
/// A row that owns both its column names and its values, so it can outlive
/// the result set it came from.
#[derive(Debug, Clone)]
pub struct OwnedRow {
    /// Column names.
    pub columns: Vec<String>,
    /// The row's values, looked up by position.
    pub values: Vec<Value>,
}
384
385impl OwnedRow {
386 pub fn get<T: FromValue>(&self, idx: usize) -> Result<T> {
387 let v = self.values.get(idx).ok_or_else(|| {
388 SQLRiteError::General(format!(
389 "column index {idx} out of bounds (row has {} columns)",
390 self.values.len()
391 ))
392 })?;
393 T::from_value(v)
394 }
395
396 pub fn get_by_name<T: FromValue>(&self, name: &str) -> Result<T> {
397 let idx = self
398 .columns
399 .iter()
400 .position(|c| c == name)
401 .ok_or_else(|| SQLRiteError::General(format!("no column named '{name}' in row")))?;
402 self.get(idx)
403 }
404}
405
/// Conversion from a borrowed [`Value`] into a Rust type.
///
/// This is the bound used by the typed getters on `Row` and `OwnedRow`.
pub trait FromValue: Sized {
    /// Builds `Self` from `v`, or returns an error when `v` cannot be
    /// represented as `Self`.
    fn from_value(v: &Value) -> Result<Self>;
}
413
414impl FromValue for i64 {
415 fn from_value(v: &Value) -> Result<Self> {
416 match v {
417 Value::Integer(n) => Ok(*n),
418 Value::Null => Err(SQLRiteError::General(
419 "expected Integer, got NULL".to_string(),
420 )),
421 other => Err(SQLRiteError::General(format!(
422 "cannot convert {other:?} to i64"
423 ))),
424 }
425 }
426}
427
428impl FromValue for f64 {
429 fn from_value(v: &Value) -> Result<Self> {
430 match v {
431 Value::Real(f) => Ok(*f),
432 Value::Integer(n) => Ok(*n as f64),
433 Value::Null => Err(SQLRiteError::General("expected Real, got NULL".to_string())),
434 other => Err(SQLRiteError::General(format!(
435 "cannot convert {other:?} to f64"
436 ))),
437 }
438 }
439}
440
441impl FromValue for String {
442 fn from_value(v: &Value) -> Result<Self> {
443 match v {
444 Value::Text(s) => Ok(s.clone()),
445 Value::Null => Err(SQLRiteError::General("expected Text, got NULL".to_string())),
446 other => Err(SQLRiteError::General(format!(
447 "cannot convert {other:?} to String"
448 ))),
449 }
450 }
451}
452
453impl FromValue for bool {
454 fn from_value(v: &Value) -> Result<Self> {
455 match v {
456 Value::Bool(b) => Ok(*b),
457 Value::Integer(n) => Ok(*n != 0),
458 Value::Null => Err(SQLRiteError::General("expected Bool, got NULL".to_string())),
459 other => Err(SQLRiteError::General(format!(
460 "cannot convert {other:?} to bool"
461 ))),
462 }
463 }
464}
465
466impl<T: FromValue> FromValue for Option<T> {
469 fn from_value(v: &Value) -> Result<Self> {
470 match v {
471 Value::Null => Ok(None),
472 other => Ok(Some(T::from_value(other)?)),
473 }
474 }
475}
476
477impl FromValue for Value {
480 fn from_value(v: &Value) -> Result<Self> {
481 Ok(v.clone())
482 }
483}
484
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a path in the system temp directory whose file name combines
    /// the process id, the current time in nanoseconds, and `name`.
    fn tmp_path(name: &str) -> std::path::PathBuf {
        let mut p = std::env::temp_dir();
        let pid = std::process::id();
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        p.push(format!("sqlrite-conn-{pid}-{nanos}-{name}.sqlrite"));
        p
    }

    /// Best-effort removal of `path` and of its `-wal` sibling; failures
    /// (such as the file not existing) are deliberately ignored.
    fn cleanup(path: &std::path::Path) {
        let _ = std::fs::remove_file(path);
        let mut wal = path.as_os_str().to_owned();
        wal.push("-wal");
        let _ = std::fs::remove_file(std::path::PathBuf::from(wal));
    }

    /// A temp database path that is cleaned up when the guard is dropped.
    ///
    /// A failing assertion unwinds past a trailing `cleanup` call and would
    /// leave the file behind; the drop guard removes it in that case too.
    /// Declare the guard before any connection so that the connection is
    /// dropped first.
    struct TempDb(std::path::PathBuf);

    impl TempDb {
        /// Reserves a fresh temp path tagged with `name`.
        fn new(name: &str) -> Self {
            Self(tmp_path(name))
        }

        /// The reserved path.
        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for TempDb {
        fn drop(&mut self) {
            cleanup(&self.0);
        }
    }

    #[test]
    fn in_memory_roundtrip() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('alice', 30);")
            .unwrap();
        conn.execute("INSERT INTO users (name, age) VALUES ('bob', 25);")
            .unwrap();

        let stmt = conn.prepare("SELECT id, name, age FROM users;").unwrap();
        let mut rows = stmt.query().unwrap();
        assert_eq!(rows.columns(), &["id", "name", "age"]);
        let mut collected: Vec<(i64, String, i64)> = Vec::new();
        while let Some(row) = rows.next().unwrap() {
            collected.push((
                row.get::<i64>(0).unwrap(),
                row.get::<String>(1).unwrap(),
                row.get::<i64>(2).unwrap(),
            ));
        }
        assert_eq!(collected.len(), 2);
        assert!(collected.iter().any(|(_, n, a)| n == "alice" && *a == 30));
        assert!(collected.iter().any(|(_, n, a)| n == "bob" && *a == 25));
    }

    #[test]
    fn file_backed_persists_across_connections() {
        let db_file = TempDb::new("persist");
        {
            let mut c1 = Connection::open(db_file.path()).unwrap();
            c1.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, label TEXT);")
                .unwrap();
            c1.execute("INSERT INTO items (label) VALUES ('one');")
                .unwrap();
        }
        {
            let mut c2 = Connection::open(db_file.path()).unwrap();
            let stmt = c2.prepare("SELECT label FROM items;").unwrap();
            let mut rows = stmt.query().unwrap();
            let first = rows.next().unwrap().expect("one row");
            assert_eq!(first.get::<String>(0).unwrap(), "one");
            assert!(rows.next().unwrap().is_none());
        }
    }

    #[test]
    fn read_only_connection_rejects_writes() {
        let db_file = TempDb::new("ro_reject");
        {
            let mut c = Connection::open(db_file.path()).unwrap();
            c.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
                .unwrap();
            c.execute("INSERT INTO t (id) VALUES (1);").unwrap();
        }
        // The writer above is dropped before the file is reopened read-only.
        let mut ro = Connection::open_read_only(db_file.path()).unwrap();
        assert!(ro.is_read_only());
        let err = ro.execute("INSERT INTO t (id) VALUES (2);").unwrap_err();
        assert!(err.to_string().contains("read-only"));
    }

    #[test]
    fn transactions_work_through_connection() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, x INTEGER);")
            .unwrap();
        conn.execute("INSERT INTO t (x) VALUES (1);").unwrap();

        conn.execute("BEGIN;").unwrap();
        assert!(conn.in_transaction());
        conn.execute("INSERT INTO t (x) VALUES (2);").unwrap();
        conn.execute("ROLLBACK;").unwrap();
        assert!(!conn.in_transaction());

        // Only the row inserted before BEGIN should survive the rollback.
        let stmt = conn.prepare("SELECT x FROM t;").unwrap();
        let rows = stmt.query().unwrap().collect_all().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get::<i64>(0).unwrap(), 1);
    }

    #[test]
    fn get_by_name_works() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER, b TEXT);").unwrap();
        conn.execute("INSERT INTO t (a, b) VALUES (42, 'hello');")
            .unwrap();

        let stmt = conn.prepare("SELECT a, b FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get_by_name::<i64>("a").unwrap(), 42);
        assert_eq!(row.get_by_name::<String>("b").unwrap(), "hello");
    }

    #[test]
    fn null_column_maps_to_none() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, note TEXT);")
            .unwrap();
        // `note` is left out of the insert, so the test expects it to read back as NULL.
        conn.execute("INSERT INTO t (id) VALUES (1);").unwrap();

        let stmt = conn.prepare("SELECT id, note FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        assert_eq!(row.get::<i64>(0).unwrap(), 1);
        assert_eq!(row.get::<Option<String>>(1).unwrap(), None);
    }

    #[test]
    fn prepare_rejects_multiple_statements() {
        let mut conn = Connection::open_in_memory().unwrap();
        let err = conn.prepare("SELECT 1; SELECT 2;").unwrap_err();
        assert!(err.to_string().contains("single statement"));
    }

    #[test]
    fn query_on_non_select_errors() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY);")
            .unwrap();
        let stmt = conn.prepare("INSERT INTO t VALUES (1);").unwrap();
        let err = stmt.query().unwrap_err();
        assert!(err.to_string().contains("SELECT"));
    }

    #[test]
    fn index_out_of_bounds_errors_cleanly() {
        let mut conn = Connection::open_in_memory().unwrap();
        conn.execute("CREATE TABLE t (a INTEGER PRIMARY KEY);")
            .unwrap();
        conn.execute("INSERT INTO t (a) VALUES (1);").unwrap();
        let stmt = conn.prepare("SELECT a FROM t;").unwrap();
        let mut rows = stmt.query().unwrap();
        let row = rows.next().unwrap().unwrap();
        let err = row.get::<i64>(99).unwrap_err();
        assert!(err.to_string().contains("out of bounds"));
    }
}