use crate::error::{Result, SQLRiteError};
use crate::sql::db::table::Table;
use crate::sql::pager::pager::{AccessMode, Pager};
use std::collections::HashMap;
use std::path::PathBuf;
/// Copy of a database's table map captured when a transaction begins.
///
/// `Database::begin_transaction` builds one by deep-cloning every table, and
/// `Database::rollback_transaction` moves `tables` back into the database to
/// undo whatever happened since the snapshot was taken.
#[derive(Debug)]
pub struct TxnSnapshot {
/// Table name -> deep clone of that table as of the moment the transaction began.
pub(crate) tables: HashMap<String, Table>,
}
/// In-memory representation of a single database: its tables plus optional
/// pager and open-transaction state.
#[derive(Debug)]
pub struct Database {
/// Name given to the database at construction time.
pub db_name: String,
/// Live tables, keyed by table name.
pub tables: HashMap<String, Table>,
/// NOTE(review): presumably the on-disk file backing this database; not
/// read or written by any code in this file — confirm against callers.
pub source_path: Option<PathBuf>,
/// Optional pager. When present, its access mode decides whether the
/// database is treated as read-only (see `is_read_only`).
pub pager: Option<Pager>,
/// `Some` while a transaction is open; holds the snapshot to restore on rollback.
pub txn: Option<TxnSnapshot>,
}
impl Database {
    /// Creates an empty database called `db_name`.
    ///
    /// The new database has no tables, no source path, no pager and no open
    /// transaction.
    pub fn new(db_name: String) -> Self {
        Self {
            tables: HashMap::new(),
            source_path: None,
            pager: None,
            txn: None,
            db_name,
        }
    }

    /// Returns `true` when a table named `table_name` exists in this database.
    pub fn contains_table(&self, table_name: String) -> bool {
        self.tables.contains_key(table_name.as_str())
    }

    /// Looks up the table named `table_name` for reading.
    ///
    /// # Errors
    ///
    /// Returns `SQLRiteError::General` with the message `"Table not found."`
    /// when no table with that name exists.
    pub fn get_table(&self, table_name: String) -> Result<&Table> {
        self.tables
            .get(table_name.as_str())
            .ok_or_else(|| SQLRiteError::General(String::from("Table not found.")))
    }

    /// Looks up the table named `table_name` for modification.
    ///
    /// # Errors
    ///
    /// Returns `SQLRiteError::General` with the message `"Table not found."`
    /// when no table with that name exists.
    pub fn get_table_mut(&mut self, table_name: String) -> Result<&mut Table> {
        match self.tables.get_mut(table_name.as_str()) {
            Some(found) => Ok(found),
            None => Err(SQLRiteError::General(String::from("Table not found."))),
        }
    }

    /// Reports whether the database is read-only.
    ///
    /// This is `true` only when a pager is attached and its access mode is
    /// `AccessMode::ReadOnly`; a database without a pager is never read-only.
    pub fn is_read_only(&self) -> bool {
        match &self.pager {
            Some(pager) => pager.access_mode() == AccessMode::ReadOnly,
            None => false,
        }
    }

    /// Returns `true` while a transaction snapshot is held.
    pub fn in_transaction(&self) -> bool {
        self.txn.is_some()
    }

    /// Opens a transaction by snapshotting every table.
    ///
    /// Each table is deep-cloned into a `TxnSnapshot` so that
    /// `rollback_transaction` can restore the pre-transaction state.
    ///
    /// # Errors
    ///
    /// Returns `SQLRiteError::General` when a transaction is already open or,
    /// failing that, when the database is read-only. The "already open" check
    /// takes precedence.
    pub fn begin_transaction(&mut self) -> Result<()> {
        let refusal = if self.txn.is_some() {
            Some("cannot BEGIN: a transaction is already open")
        } else if self.is_read_only() {
            Some("cannot BEGIN: database is opened read-only")
        } else {
            None
        };
        if let Some(reason) = refusal {
            return Err(SQLRiteError::General(reason.to_string()));
        }

        let mut saved = HashMap::with_capacity(self.tables.len());
        for (name, table) in &self.tables {
            saved.insert(name.clone(), table.deep_clone());
        }
        self.txn = Some(TxnSnapshot { tables: saved });
        Ok(())
    }

    /// Closes the open transaction, keeping the current table state.
    ///
    /// This method only discards the snapshot; it does not touch `tables`
    /// or the pager.
    ///
    /// # Errors
    ///
    /// Returns `SQLRiteError::General` when no transaction is open.
    pub fn commit_transaction(&mut self) -> Result<()> {
        // `take` leaves `None` behind either way, so the not-open case is unchanged.
        match self.txn.take() {
            Some(_) => Ok(()),
            None => Err(SQLRiteError::General(
                "cannot COMMIT: no transaction is open".to_string(),
            )),
        }
    }

    /// Closes the open transaction and restores the tables captured at BEGIN.
    ///
    /// # Errors
    ///
    /// Returns `SQLRiteError::General` when no transaction is open.
    pub fn rollback_transaction(&mut self) -> Result<()> {
        match self.txn.take() {
            Some(snapshot) => {
                self.tables = snapshot.tables;
                Ok(())
            }
            None => Err(SQLRiteError::General(
                "cannot ROLLBACK: no transaction is open".to_string(),
            )),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::parser::create::CreateQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// CREATE TABLE statement shared by the table-lookup tests (four columns).
    const CONTACTS_DDL: &str = "CREATE TABLE contacts (
id INTEGER PRIMARY KEY,
first_name TEXT NOT NULL,
last_name TEXT NOT NULl,
email TEXT NOT NULL UNIQUE
);";

    /// Builds a database named `my_db` that holds only the `contacts` table
    /// parsed from `CONTACTS_DDL`.
    fn db_with_contacts() -> Database {
        let mut db = Database::new("my_db".to_string());
        let mut statements = Parser::parse_sql(&SQLiteDialect {}, CONTACTS_DDL).unwrap();
        if statements.len() > 1 {
            panic!("Expected a single query statement, but there are more then 1.")
        }
        let statement = statements.pop().unwrap();
        let create_query = CreateQuery::new(&statement).unwrap();
        // Copy the name out first: `Table::new` consumes `create_query`.
        let name = create_query.table_name.to_string();
        db.tables.insert(name, Table::new(create_query));
        db
    }

    /// A freshly constructed database keeps the name it was given.
    #[test]
    fn new_database_create_test() {
        let expected = "my_db".to_string();
        let db = Database::new(expected.clone());
        assert_eq!(db.db_name, expected);
    }

    /// A table inserted under its parsed name is reported by `contains_table`.
    #[test]
    fn contains_table_test() {
        let db = db_with_contacts();
        assert!(db.contains_table("contacts".to_string()));
    }

    /// `get_table` and `get_table_mut` both resolve the inserted table, and
    /// changes made through the mutable handle are visible.
    #[test]
    fn get_table_test() {
        let mut db = db_with_contacts();

        let shared = db.get_table("contacts".to_string()).unwrap();
        assert_eq!(shared.columns.len(), 4);

        let exclusive = db.get_table_mut("contacts".to_string()).unwrap();
        exclusive.last_rowid += 1;
        assert_eq!(exclusive.columns.len(), 4);
        assert_eq!(exclusive.last_rowid, 1);
    }
}