Skip to main content

sqlrite/sql/db/
database.rs

1use crate::error::{Result, SQLRiteError};
2use crate::sql::db::table::Table;
3use crate::sql::pager::pager::{AccessMode, Pager};
4use std::collections::HashMap;
5use std::path::PathBuf;
6
7/// Snapshot of the mutable in-memory state taken at `BEGIN` time so
8/// `ROLLBACK` can restore it. See `begin_transaction`, `rollback_transaction`.
9/// `tables` is deep-cloned (the `Table::deep_clone` helper reallocates
10/// the `Arc<Mutex<_>>` row storage so snapshot and live state don't
11/// share a map).
12#[derive(Debug)]
13pub struct TxnSnapshot {
14    pub(crate) tables: HashMap<String, Table>,
15}
16
/// Free-page fraction above which an auto-VACUUM runs by default.
///
/// The check happens after a page-releasing DDL statement (DROP TABLE,
/// DROP INDEX, ALTER TABLE DROP COLUMN). The 25% figure mirrors SQLite's
/// classic heuristic. Each connection can override it with
/// [`Database::set_auto_vacuum_threshold`] (or
/// `Connection::set_auto_vacuum_threshold`); passing `None` turns the
/// trigger off.
pub const DEFAULT_AUTO_VACUUM_THRESHOLD: f32 = 0.25;
23
24/// The database is represented by this structure.assert_eq!
25#[derive(Debug)]
26pub struct Database {
27    /// Name of this database. (schema name, not filename)
28    pub db_name: String,
29    /// HashMap of tables in this database
30    pub tables: HashMap<String, Table>,
31    /// If `Some`, every committing SQL statement auto-flushes the DB to
32    /// this path. `None` → transient in-memory mode (the default; the
33    /// REPL only enters persistent mode after `.open FILE`).
34    pub source_path: Option<PathBuf>,
35    /// Long-lived pager attached when the database is file-backed. Keeps
36    /// an in-memory snapshot of every page so auto-saves can diff
37    /// against the last-committed state and skip rewriting unchanged
38    /// pages. `None` means "in-memory only" or "not yet opened".
39    pub pager: Option<Pager>,
40    /// Active transaction state (Phase 4f). `Some` between `BEGIN` and
41    /// the matching `COMMIT` / `ROLLBACK`. While set:
42    /// - auto-save is suppressed (mutations stay in-memory)
43    /// - nested `BEGIN` is rejected
44    /// - `ROLLBACK` restores `tables` from the snapshot
45    pub txn: Option<TxnSnapshot>,
46    /// Auto-VACUUM trigger (SQLR-10). After a page-releasing DDL
47    /// (DROP TABLE / DROP INDEX / ALTER TABLE DROP COLUMN) commits and
48    /// flushes, if the freelist exceeds this fraction of `page_count`
49    /// the engine quietly compacts the file. `None` disables the
50    /// trigger; defaults to `Some(DEFAULT_AUTO_VACUUM_THRESHOLD)`
51    /// (SQLite parity at 25%). Per-connection runtime state — not
52    /// persisted across reopens.
53    pub auto_vacuum_threshold: Option<f32>,
54}
55
56impl Database {
57    /// Creates an empty in-memory `Database`.
58    ///
59    /// # Examples
60    ///
61    /// ```
62    /// use sqlrite::Database;
63    /// let mut db = Database::new("my_db".to_string());
64    /// ```
65    pub fn new(db_name: String) -> Self {
66        Database {
67            db_name,
68            tables: HashMap::new(),
69            source_path: None,
70            pager: None,
71            txn: None,
72            auto_vacuum_threshold: Some(DEFAULT_AUTO_VACUUM_THRESHOLD),
73        }
74    }
75
76    /// Returns the current auto-VACUUM threshold, or `None` if disabled.
77    /// See [`Database::set_auto_vacuum_threshold`] for semantics.
78    pub fn auto_vacuum_threshold(&self) -> Option<f32> {
79        self.auto_vacuum_threshold
80    }
81
82    /// Sets the auto-VACUUM threshold (SQLR-10). `Some(t)` with `t` in
83    /// `0.0..=1.0` arms the trigger: after a page-releasing DDL
84    /// commits, if the freelist exceeds `t * page_count` the engine
85    /// runs a full-file compact. `None` disables the trigger. Values
86    /// outside `0.0..=1.0` (or NaN / infinite) return a typed error
87    /// rather than silently saturating.
88    pub fn set_auto_vacuum_threshold(&mut self, threshold: Option<f32>) -> Result<()> {
89        if let Some(t) = threshold {
90            if !t.is_finite() || !(0.0..=1.0).contains(&t) {
91                return Err(SQLRiteError::General(format!(
92                    "auto_vacuum_threshold must be in 0.0..=1.0, got {t}"
93                )));
94            }
95        }
96        self.auto_vacuum_threshold = threshold;
97        Ok(())
98    }
99
100    /// Returns true if the database contains a table with the specified key as a table name.
101    ///
102    pub fn contains_table(&self, table_name: String) -> bool {
103        self.tables.contains_key(&table_name)
104    }
105
106    /// Returns an immutable reference of `sql::db::table::Table` if the database contains a
107    /// table with the specified key as a table name.
108    ///
109    pub fn get_table(&self, table_name: String) -> Result<&Table> {
110        if let Some(table) = self.tables.get(&table_name) {
111            Ok(table)
112        } else {
113            Err(SQLRiteError::General(String::from("Table not found.")))
114        }
115    }
116
117    /// Returns an mutable reference of `sql::db::table::Table` if the database contains a
118    /// table with the specified key as a table name.
119    ///
120    pub fn get_table_mut(&mut self, table_name: String) -> Result<&mut Table> {
121        if let Some(table) = self.tables.get_mut(&table_name) {
122            Ok(table)
123        } else {
124            Err(SQLRiteError::General(String::from("Table not found.")))
125        }
126    }
127
128    /// Returns `true` if this database is attached to a file and that
129    /// file was opened in [`AccessMode::ReadOnly`]. In-memory databases
130    /// (no pager) and read-write file-backed databases both return
131    /// `false`. Callers use this to reject mutating SQL at the
132    /// dispatcher level so the in-memory tables don't drift away from
133    /// disk on a would-be INSERT / UPDATE / DELETE.
134    pub fn is_read_only(&self) -> bool {
135        self.pager
136            .as_ref()
137            .is_some_and(|p| p.access_mode() == AccessMode::ReadOnly)
138    }
139
140    /// Returns `true` while a `BEGIN … COMMIT`/`ROLLBACK` block is open.
141    pub fn in_transaction(&self) -> bool {
142        self.txn.is_some()
143    }
144
145    /// Starts a transaction: snapshots every table deep-cloned so that
146    /// a later `rollback_transaction` can restore the pre-BEGIN state.
147    /// Nested transactions are rejected — explicit savepoints are not
148    /// on this phase's roadmap. Errors on a read-only database.
149    pub fn begin_transaction(&mut self) -> Result<()> {
150        if self.in_transaction() {
151            return Err(SQLRiteError::General(
152                "cannot BEGIN: a transaction is already open".to_string(),
153            ));
154        }
155        if self.is_read_only() {
156            return Err(SQLRiteError::General(
157                "cannot BEGIN: database is opened read-only".to_string(),
158            ));
159        }
160        let snapshot = TxnSnapshot {
161            tables: self
162                .tables
163                .iter()
164                .map(|(k, v)| (k.clone(), v.deep_clone()))
165                .collect(),
166        };
167        self.txn = Some(snapshot);
168        Ok(())
169    }
170
171    /// Drops the transaction snapshot and returns it for the caller to
172    /// discard. The in-memory `tables` state is the new committed state;
173    /// the caller is responsible for flushing to disk via the pager.
174    /// Errors if no transaction is open.
175    pub fn commit_transaction(&mut self) -> Result<()> {
176        if self.txn.is_none() {
177            return Err(SQLRiteError::General(
178                "cannot COMMIT: no transaction is open".to_string(),
179            ));
180        }
181        self.txn = None;
182        Ok(())
183    }
184
185    /// Restores `tables` from the transaction snapshot and clears it.
186    /// Errors if no transaction is open.
187    pub fn rollback_transaction(&mut self) -> Result<()> {
188        let Some(snapshot) = self.txn.take() else {
189            return Err(SQLRiteError::General(
190                "cannot ROLLBACK: no transaction is open".to_string(),
191            ));
192        };
193        self.tables = snapshot.tables;
194        Ok(())
195    }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::create::CreateQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// DDL for the `contacts` fixture table shared by several tests.
    const CONTACTS_DDL: &str = "CREATE TABLE contacts (
        id INTEGER PRIMARY KEY,
        first_name TEXT NOT NULL,
        last_name TEXT NOT NULL,
        email TEXT NOT NULL UNIQUE
    );";

    /// Parses `CONTACTS_DDL` and returns `(table_name, table)`, ready to be
    /// inserted into `Database::tables`.
    fn contacts_table() -> (String, Table) {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, CONTACTS_DDL).unwrap();
        assert_eq!(ast.len(), 1, "Expected exactly one query statement.");
        let statement = ast.pop().unwrap();
        let create_query = CreateQuery::new(&statement).unwrap();
        (create_query.table_name.to_string(), Table::new(create_query))
    }

    /// Builds a fresh in-memory database that already holds the `contacts`
    /// table.
    fn db_with_contacts() -> Database {
        let mut db = Database::new("my_db".to_string());
        let (name, table) = contacts_table();
        db.tables.insert(name, table);
        db
    }

    #[test]
    fn new_database_create_test() {
        let db_name = String::from("my_db");
        let db = Database::new(db_name.to_string());
        assert_eq!(db.db_name, db_name);
        assert!(db.tables.is_empty());
        assert!(db.source_path.is_none());
        assert!(db.pager.is_none());
        assert!(!db.in_transaction());
        // No pager attached, so the database cannot be read-only.
        assert!(!db.is_read_only());
        assert_eq!(
            db.auto_vacuum_threshold(),
            Some(DEFAULT_AUTO_VACUUM_THRESHOLD)
        );
    }

    #[test]
    fn contains_table_test() {
        let db = db_with_contacts();
        assert!(db.contains_table("contacts".to_string()));
        assert!(!db.contains_table("missing".to_string()));
    }

    #[test]
    fn get_table_test() {
        let mut db = db_with_contacts();

        let table = db.get_table(String::from("contacts")).unwrap();
        assert_eq!(table.columns.len(), 4);

        let table = db.get_table_mut(String::from("contacts")).unwrap();
        table.last_rowid += 1;
        assert_eq!(table.columns.len(), 4);
        assert_eq!(table.last_rowid, 1);
    }

    #[test]
    fn get_missing_table_is_an_error() {
        let mut db = Database::new("my_db".to_string());
        assert!(db.get_table(String::from("missing")).is_err());
        assert!(db.get_table_mut(String::from("missing")).is_err());
    }

    #[test]
    fn rollback_discards_changes_made_inside_transaction() {
        let mut db = Database::new("my_db".to_string());
        db.begin_transaction().unwrap();
        assert!(db.in_transaction());

        let (name, table) = contacts_table();
        db.tables.insert(name, table);
        assert!(db.contains_table("contacts".to_string()));

        db.rollback_transaction().unwrap();
        assert!(!db.in_transaction());
        assert!(!db.contains_table("contacts".to_string()));
    }

    #[test]
    fn commit_keeps_changes_and_closes_transaction() {
        let mut db = Database::new("my_db".to_string());
        db.begin_transaction().unwrap();

        let (name, table) = contacts_table();
        db.tables.insert(name, table);

        db.commit_transaction().unwrap();
        assert!(!db.in_transaction());
        assert!(db.contains_table("contacts".to_string()));
    }

    #[test]
    fn transaction_misuse_is_rejected() {
        let mut db = Database::new("my_db".to_string());
        assert!(db.commit_transaction().is_err());
        assert!(db.rollback_transaction().is_err());

        db.begin_transaction().unwrap();
        assert!(db.begin_transaction().is_err());
        // A rejected nested BEGIN must leave the outer transaction open.
        assert!(db.in_transaction());
    }

    #[test]
    fn auto_vacuum_threshold_validation() {
        let mut db = Database::new("my_db".to_string());

        // Both inclusive bounds are accepted.
        db.set_auto_vacuum_threshold(Some(0.0)).unwrap();
        db.set_auto_vacuum_threshold(Some(1.0)).unwrap();
        db.set_auto_vacuum_threshold(Some(0.5)).unwrap();
        assert_eq!(db.auto_vacuum_threshold(), Some(0.5));

        for bad in [-0.1, 1.5, f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(db.set_auto_vacuum_threshold(Some(bad)).is_err());
            // A rejected value must not overwrite the previous setting.
            assert_eq!(db.auto_vacuum_threshold(), Some(0.5));
        }

        db.set_auto_vacuum_threshold(None).unwrap();
        assert_eq!(db.auto_vacuum_threshold(), None);
    }
}