use std::path::Path;
use std::sync::{Arc, Mutex, MutexGuard};
use rusqlite::Connection;
use crate::database::error::DbError;
use crate::database::record::DbRecord;
use crate::database::repository;
/// Handle to a single SQLite connection shared behind an `Arc<Mutex<_>>`.
///
/// Cloning a `DataBase` is cheap: clones share the same underlying
/// `Connection`, and access is serialized through the mutex.
#[derive(Clone, Debug)]
pub struct DataBase {
    /// The one connection shared by every clone of this handle.
    inner: Arc<Mutex<Connection>>,
}
impl DataBase {
pub fn open(path: &Path) -> Result<Self, DbError> {
if let Some(parent) = path.parent() {
std::fs::create_dir_all(parent).ok(); }
let conn: Connection = Connection::open(path)?;
configure(&conn)?;
Ok(Self {
inner: Arc::new(Mutex::new(conn)),
})
}
pub fn in_memory() -> Result<Self, DbError> {
let conn: Connection = Connection::open_in_memory()?;
configure(&conn)?;
Ok(Self {
inner: Arc::new(Mutex::new(conn)),
})
}
pub fn register<T: DbRecord>(&self) -> Result<(), DbError> {
self.ensureTable::<T>()
}
pub fn getRepository<T: DbRecord>(&self) -> repository::Repository<T> {
repository::Repository::attached(self.clone())
}
pub fn ensureRepository<T: DbRecord>(&self) -> Result<repository::Repository<T>, DbError> {
repository::Repository::new(self.clone())
}
pub(crate) fn lock(&self) -> MutexGuard<'_, Connection> {
self.inner.lock().expect("DB mutex poisoned")
}
pub(crate) fn ensureTable<T: DbRecord>(&self) -> Result<(), DbError> {
validateIdentifier(T::table_name())?;
let createTableQuery: String = generate_create_table_query::<T>()?;
let createIndexesQuery: Vec<String> = generate_create_indexes_query::<T>()?;
let conn: MutexGuard<'_, Connection> = self.lock();
conn.execute_batch(&format!(
"BEGIN;\n{}\n{}\nCOMMIT;",
createTableQuery,
createIndexesQuery.join("\n")
))?;
Ok(())
}
}
fn configure(conn: &Connection) -> Result<(), DbError> {
conn.pragma_update(None, "journal_mode", "WAL")?;
conn.pragma_update(None, "foreign_keys", true)?;
conn.pragma_update(None, "synchronous", "NORMAL")?;
Ok(())
}
/// Wraps `name` in double quotes for use as an SQL identifier, doubling any
/// embedded `"` as the SQL quoting rules require.
pub(crate) fn quoteIdentifier(name: &str) -> String {
    let mut quoted = String::with_capacity(name.len() + 2);
    quoted.push('"');
    for ch in name.chars() {
        if ch == '"' {
            quoted.push_str("\"\"");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
/// Accepts only plain ASCII identifiers: a letter or `_` first, then letters,
/// digits or `_`. An empty name is rejected.
///
/// # Errors
/// Returns `DbError::InvalidIdentifier` carrying the offending name.
pub(crate) fn validateIdentifier(name: &str) -> Result<(), DbError> {
    let mut chars = name.chars();
    // The first character is stricter (no digits); an empty name fails here too.
    let starts_well = matches!(chars.next(), Some(first) if first.is_ascii_alphabetic() || first == '_');
    if starts_well && chars.all(|c| c.is_ascii_alphanumeric() || c == '_') {
        return Ok(());
    }
    Err(DbError::InvalidIdentifier(name.to_string()))
}
/// Builds the `CREATE TABLE IF NOT EXISTS` statement for `T`.
///
/// The table always starts with an auto-incrementing `"id"` primary key,
/// followed by one definition per entry in `T::columns()`, each rendered by
/// the column's `to_sql_fragment()`.
///
/// # Errors
/// Returns `DbError::InvalidIdentifier` for the first column whose name fails
/// `validateIdentifier`. The table name itself is not validated here.
fn generate_create_table_query<T: DbRecord>() -> Result<String, DbError> {
    let primary_key: String = format!("{} INTEGER PRIMARY KEY AUTOINCREMENT", quoteIdentifier("id"));
    let declared: Vec<String> = T::columns()
        .iter()
        .map(|column| -> Result<String, DbError> {
            validateIdentifier(column.name)?;
            Ok(column.to_sql_fragment())
        })
        .collect::<Result<Vec<String>, DbError>>()?;
    let mut definitions: Vec<String> = vec![primary_key];
    definitions.extend(declared);
    let body: String = definitions.join(",\n ");
    Ok(format!(
        "CREATE TABLE IF NOT EXISTS {} (\n {}\n);",
        quoteIdentifier(T::table_name()),
        body
    ))
}
/// Builds one `CREATE [UNIQUE] INDEX IF NOT EXISTS` statement per index
/// declared by `T::indexes()`.
///
/// Index names follow `idx_<table>_<col1>_<col2>…`, or `uidx_…` for unique
/// indexes.
///
/// # Errors
/// Returns `DbError::InvalidIdentifier` for the first indexed column whose
/// name fails `validateIdentifier`.
fn generate_create_indexes_query<T: DbRecord>() -> Result<Vec<String>, DbError> {
    let table_name: &str = T::table_name();
    let mut statements: Vec<String> = Vec::new();
    for index in T::indexes() {
        // Reject unsafe column names before they are spliced into SQL.
        index
            .columns
            .iter()
            .try_for_each(|column| validateIdentifier(column))?;
        let (keyword, prefix): (&str, &str) = if index.unique {
            ("UNIQUE ", "uidx")
        } else {
            ("", "idx")
        };
        let index_name: String = format!("{}_{}_{}", prefix, table_name, index.columns.join("_"));
        let column_list: String = index
            .columns
            .iter()
            .map(|column| quoteIdentifier(column))
            .collect::<Vec<String>>()
            .join(", ");
        statements.push(format!(
            "CREATE {}INDEX IF NOT EXISTS {} ON {} ({});",
            keyword,
            quoteIdentifier(&index_name),
            quoteIdentifier(table_name),
            column_list
        ));
    }
    Ok(statements)
}
/// Converts a result row into a `ValueSet`.
///
/// The row is expected to contain the `id` column first, followed by one
/// column per entry in `column_names`. That matches the layout produced by
/// `generate_select_columns_sql`, so `1 + column_names.len()` columns are read.
///
/// # Errors
/// Returns the first `rusqlite::Error` raised while reading a column.
pub(crate) fn row_to_valueset(
    row: &rusqlite::Row<'_>,
    column_names: &[&'static str],
) -> rusqlite::Result<crate::database::record::ValueSet> {
    let total: usize = column_names.len() + 1;
    let mut raw_values: Vec<rusqlite::types::Value> = Vec::with_capacity(total);
    for index in 0..total {
        raw_values.push(row.get::<_, rusqlite::types::Value>(index)?);
    }
    Ok(crate::database::record::ValueSet::new(
        raw_values
            .into_iter()
            .map(crate::database::value::from_rusqlite)
            .collect(),
        column_names.to_vec(),
    ))
}
/// Returns the comma-separated, quoted column list for selecting a `T` row:
/// `"id"` first, then every column from `T::columns()` in declaration order.
pub(crate) fn generate_select_columns_sql<T: DbRecord>() -> String {
    std::iter::once(quoteIdentifier("id"))
        .chain(T::columns().iter().map(|column| quoteIdentifier(column.name)))
        .collect::<Vec<String>>()
        .join(", ")
}
pub(crate) fn column_names<T: DbRecord>() -> Vec<&'static str> {
T::columns().iter().map(|c: &super::Column| c.name).collect()
}