use std::{
cell::RefCell, convert::Infallible, marker::PhantomData, rc::Rc, sync::atomic::AtomicI64,
};
use rusqlite::ErrorCode;
use self_cell::{MutBorrow, self_cell};
use crate::{
IntoExpr, IntoSelect, Table, TableRow,
error::FromConflict,
lower::{
self, emit,
list_writer::{Alias, ListWriter},
ord_rc::OrdRc,
},
migrate::{Schema, check_schema, schema_version, user_version},
migration::Config,
mutable::Mutable,
pool::Pool,
private::{IntoJoinable, Reader},
query::{OwnedRows, Query, track_stmt},
rows::Rows,
value::{DbTyp, OptTable},
};
/// A handle to a SQLite database whose layout is described by the schema type `S`.
///
/// Transactions borrow connections from `pool` and push them back when done.
/// `schema_version` caches the last SQLite `schema_version` that was validated, so
/// `Transaction::new_checked` only re-runs the schema check when it changes.
pub struct Database<S> {
    /// Connections handed out to transactions (popped on start, pushed back on finish).
    pub(crate) pool: Pool,
    /// Last validated SQLite `schema_version`; compared at the start of every transaction.
    pub(crate) schema_version: AtomicI64,
    /// Ties the database to its schema type without storing a value of it.
    pub(crate) schema: PhantomData<S>,
    /// Fair lock held by `transaction_mut_local` from before it pops a connection
    /// until the user closure has returned.
    pub(crate) mut_lock: parking_lot::FairMutex<()>,
}
impl<S: Schema> Database<S> {
    /// Opens the database described by `config` as schema `S`.
    ///
    /// # Panics
    /// Panics when `Self::migrator(config)` returns `None`, or when the
    /// migrator's `finish()` returns `None`; the panic messages attribute these
    /// to an older and a newer stored schema version respectively.
    pub fn new(config: Config) -> Self {
        let migrator = match Self::migrator(config) {
            Some(migrator) => migrator,
            None => panic!("schema version {}, but got an older version", S::VERSION),
        };
        match migrator.finish() {
            Some(database) => database,
            None => panic!("schema version {}, but got a new version", S::VERSION),
        }
    }
}
use rusqlite::Connection;
// The transaction is wrapped in `Option` so that `OwnedTransaction::with` can move
// it out (via `Option::take`) while the connection it borrows stays in the cell.
type RTransaction<'x> = Option<rusqlite::Transaction<'x>>;
// Self-referential pair: a connection (owner) plus the rusqlite transaction that
// borrows it (dependent). Keeping them in one value lets the pair be stored and
// moved as a unit, e.g. inside the `TXN` thread-local.
self_cell!(
    pub struct OwnedTransaction {
        owner: MutBorrow<Connection>,
        #[covariant]
        dependent: RTransaction,
    }
);
// SAFETY (reviewer's reading, confirm): the connection and the only transaction
// that borrows it live in the same self-referential cell, so moving the cell to
// another thread moves both together and leaves no reference to the connection
// behind. References handed out by `get` are tied to `&self` and cannot outlive
// the cell. `Sync` is deliberately not implemented (asserted on the next line),
// so the cell is never accessed from two threads at once.
unsafe impl Send for OwnedTransaction {}
assert_not_impl_any! {OwnedTransaction: Sync}
thread_local! {
    /// The transaction currently open on this thread, if any.
    ///
    /// Installed by `Transaction::new_checked` and taken back by
    /// `Database::transaction_local` / `transaction_mut_local`; the query and
    /// mutation helpers in this file reach the connection through it.
    pub(crate) static TXN: RefCell<Option<TransactionWithRows>> = const { RefCell::new(None) };
}
impl OwnedTransaction {
    /// Borrows the open rusqlite transaction.
    ///
    /// Panics if the transaction has already been moved out; only `with` does
    /// that, and `with` consumes `self`.
    pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
        let slot = self.borrow_dependent();
        slot.as_ref().unwrap()
    }

    /// Moves the transaction out, hands it to `f` (e.g. to commit or roll it
    /// back), and returns the connection it was borrowing.
    pub(crate) fn with(
        mut self,
        f: impl FnOnce(rusqlite::Transaction<'_>),
    ) -> rusqlite::Connection {
        self.with_dependent_mut(|_owner, slot| {
            let txn = slot.take().unwrap();
            f(txn);
        });
        let borrow = self.into_owner();
        borrow.into_inner()
    }
}
// Open result sets, stored in a slab; their lifetime is tied to the owning
// transaction below.
type OwnedRowsVec<'x> = slab::Slab<OwnedRows<'x>>;
// An open transaction together with the `OwnedRows` values that borrow it.
// This is the value stored in the `TXN` thread-local.
self_cell!(
    pub struct TransactionWithRows {
        owner: OwnedTransaction,
        #[not_covariant]
        dependent: OwnedRowsVec,
    }
);
impl TransactionWithRows {
    /// Wraps `txn` with no open result sets.
    pub(crate) fn new_empty(txn: OwnedTransaction) -> Self {
        Self::new(txn, |_owner| slab::Slab::new())
    }

    /// Borrows the underlying rusqlite transaction.
    pub(crate) fn get(&self) -> &rusqlite::Transaction<'_> {
        let owned = self.borrow_owner();
        owned.get()
    }
}
impl<S: Send + Sync + Schema> Database<S> {
    /// Runs `f` on a new scoped thread, waits for it, and returns its result.
    ///
    /// Transactions are tracked in the `TXN` thread-local, which starts out as
    /// `None` on a fresh thread. If `f` panics, the panic is re-raised on the
    /// calling thread with its original payload.
    fn run_on_scoped_thread<R: Send>(f: impl Send + FnOnce() -> R) -> R {
        match std::thread::scope(|scope| scope.spawn(f).join()) {
            Ok(val) => val,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    #[doc = include_str!("database/transaction.md")]
    pub fn transaction<R: Send>(&self, f: impl Send + FnOnce(&'static Transaction<S>) -> R) -> R {
        Self::run_on_scoped_thread(|| self.transaction_local(f))
    }

    /// Runs a read transaction on the current thread.
    ///
    /// Pops a connection, opens a default (deferred) rusqlite transaction on it,
    /// and installs it in `TXN` while `f` runs. Afterwards the connection is
    /// pushed back into the pool. The transaction is never committed explicitly;
    /// it ends when the rusqlite `Transaction` is dropped while the connection is
    /// unwrapped (rusqlite rolls back on drop by default).
    ///
    /// If `f` panics, the connection is not pushed back. On the scoped threads
    /// used by `transaction`, the thread-local value is dropped when the thread
    /// exits.
    pub(crate) fn transaction_local<R>(&self, f: impl FnOnce(&'static Transaction<S>) -> R) -> R {
        let conn = self.pool.pop();
        let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            Some(conn.borrow_mut().transaction().unwrap())
        });
        let res = f(Transaction::new_checked(owned, &self.schema_version));
        let owned = TXN.take().unwrap().into_owner();
        self.pool.push(owned.into_owner().into_inner());
        res
    }

    #[doc = include_str!("database/transaction_mut.md")]
    pub fn transaction_mut<O: Send, E: Send>(
        &self,
        f: impl Send + FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
    ) -> Result<O, E> {
        Self::run_on_scoped_thread(|| self.transaction_mut_local(f))
    }

    /// Runs a write transaction on the current thread.
    ///
    /// Takes `mut_lock`, pops a connection, and begins an IMMEDIATE transaction
    /// on it. `f` runs with the transaction installed in `TXN`. The transaction
    /// is committed if `f` returns `Ok` and rolled back if it returns `Err`;
    /// either way the connection is returned to the pool.
    ///
    /// # Panics
    /// Panics if beginning, committing, or rolling back the transaction fails.
    pub(crate) fn transaction_mut_local<O, E>(
        &self,
        f: impl FnOnce(&'static mut Transaction<S>) -> Result<O, E>,
    ) -> Result<O, E> {
        // Taken before popping a connection, so waiting writers do not hold one.
        let guard = self.mut_lock.lock();
        let conn = self.pool.pop();
        let owned = OwnedTransaction::new(MutBorrow::new(conn), |conn| {
            let txn = conn
                .borrow_mut()
                .transaction_with_behavior(rusqlite::TransactionBehavior::Immediate)
                .unwrap();
            Some(txn)
        });
        let res = f(Transaction::new_checked(owned, &self.schema_version));
        // NOTE(review): the lock is released before commit/rollback; presumably
        // SQLite's IMMEDIATE (reserved) lock keeps other writers out until the
        // transaction finishes. Confirm this ordering is intended.
        drop(guard);
        let owned = TXN.take().unwrap().into_owner();
        let conn = if res.is_ok() {
            owned.with(|x| x.commit().unwrap())
        } else {
            owned.with(|x| x.rollback().unwrap())
        };
        self.pool.push(conn);
        res
    }

    #[doc = include_str!("database/transaction_mut_ok.md")]
    pub fn transaction_mut_ok<R: Send>(
        &self,
        f: impl Send + FnOnce(&'static mut Transaction<S>) -> R,
    ) -> R {
        // The error type is `Infallible`, so this pattern is irrefutable and the
        // compiler (not a runtime unwrap) proves the transaction always commits.
        let Ok(val) = self.transaction_mut(|txn| Ok::<R, Infallible>(f(txn)));
        val
    }

    /// Takes a raw connection out of the pool with foreign-key enforcement enabled.
    ///
    /// The connection is handed to the caller and is not returned to the pool.
    ///
    /// # Panics
    /// Panics if setting the `foreign_keys` pragma fails.
    pub fn rusqlite_connection(&self) -> rusqlite::Connection {
        let conn = self.pool.pop();
        conn.pragma_update(None, "foreign_keys", "ON").unwrap();
        conn
    }
}
/// Handle to the transaction that is open on the current thread.
///
/// The handle holds no data: its fields are only `PhantomData`, and the actual
/// rusqlite transaction lives in the `TXN` thread-local. Because `_local` is a
/// raw-pointer marker, the handle is neither `Send` nor `Sync`, so it cannot
/// leave the thread that owns the transaction.
pub struct Transaction<S> {
    pub(crate) _p2: PhantomData<S>,
    pub(crate) _local: PhantomData<*const ()>,
}
impl<S> Transaction<S> {
pub(crate) fn new() -> Self {
Self {
_p2: PhantomData,
_local: PhantomData,
}
}
pub(crate) fn copy(&self) -> Self {
Self::new()
}
pub(crate) fn new_ref() -> &'static mut Self {
Box::leak(Box::new(Self::new()))
}
}
impl<S: Schema> Transaction<S> {
    /// Installs `txn` as this thread's transaction and returns a handle to it.
    ///
    /// If SQLite's `schema_version` differs from the cached value in `expected`,
    /// the schema is re-validated with `check_schema` and the cache is updated.
    /// Otherwise the check is skipped.
    ///
    /// # Panics
    /// Panics if the schema version changed and the stored `user_version` no
    /// longer equals `S::VERSION`.
    pub(crate) fn new_checked(txn: OwnedTransaction, expected: &AtomicI64) -> &'static mut Self {
        let observed = schema_version(txn.get());
        let schema_changed = observed != expected.load(std::sync::atomic::Ordering::Relaxed);
        if schema_changed && user_version(txn.get()).unwrap() != S::VERSION {
            panic!("The database user_version changed unexpectedly")
        }
        TXN.set(Some(TransactionWithRows::new_empty(txn)));
        if schema_changed {
            // `check_schema` only receives a data-less handle, so it presumably
            // reads the connection through `TXN`; that is why it runs after the set.
            check_schema::<S>(Self::new_ref(), false);
            expected.store(observed, std::sync::atomic::Ordering::Relaxed);
        }
        // Handles are produced by leaking a box, which is only free for ZSTs.
        const {
            assert!(size_of::<Self>() == 0);
        }
        Self::new_ref()
    }
}
impl<S> Transaction<S> {
    /// Creates an empty query builder scoped to this transaction and passes it to `f`.
    pub fn query<'t, R>(&'t self, f: impl FnOnce(&mut Query<'t, '_, S>) -> R) -> R {
        let rows = Rows {
            phantom: PhantomData,
            ast: Default::default(),
            _p: PhantomData,
        };
        let mut builder = Query {
            q: rows,
            phantom: PhantomData,
        };
        f(&mut builder)
    }

    /// Evaluates `val` as a query and returns its single result row.
    ///
    /// # Panics
    /// Panics if the query yields no rows. In debug builds it also asserts that
    /// there is no second row.
    pub fn query_one<O: 'static>(&self, val: impl IntoSelect<'static, S, Out = O>) -> O {
        let mut results = self.query(|q| q.into_iter(val.into_select()));
        let first = results.next().unwrap();
        debug_assert!(results.next().is_none(), "query should return one row");
        first
    }

    /// Evaluates `val` and wraps the result as a lazily loaded value.
    pub fn lazy<'t, T: OptTable<Schema = S>>(
        &'t self,
        val: impl IntoExpr<'static, S, Typ = T>,
    ) -> T::Lazy<'t> {
        let out = self.query_one(val.into_expr());
        T::out_to_lazy(out)
    }

    /// Iterates over every row produced by joining `val`, yielding each lazily.
    pub fn lazy_iter<'t, T: Table<Schema = S>>(
        &'t self,
        val: impl IntoJoinable<'static, S, Typ = TableRow<T>>,
    ) -> LazyIter<'t, T> {
        let joinable = val.into_joinable();
        self.query(|rows| {
            let table = rows.join(joinable);
            let iter = rows.into_iter(table);
            LazyIter { txn: self, iter }
        })
    }

    /// Loads the row selected by `val` as a mutable view.
    pub fn mutable<'t, T: OptTable<Schema = S>>(
        &'t mut self,
        val: impl IntoExpr<'static, S, Typ = T>,
    ) -> T::Mutable<'t> {
        let selected = self.query_one(T::select_opt_mutable(val.into_expr()));
        T::into_mutable(selected)
    }

    /// Loads every row produced by joining `val` as mutable views.
    pub fn mutable_vec<'t, T: Table<Schema = S>>(
        &'t mut self,
        val: impl IntoJoinable<'static, S, Typ = TableRow<T>>,
    ) -> Vec<Mutable<'t, T>> {
        let joinable = val.into_joinable();
        self.query(|rows| {
            let row = rows.join(joinable);
            let selection = (T::into_select(row.clone()), row);
            rows.into_vec(selection)
                .into_iter()
                .map(TableRow::<T>::into_mutable)
                .collect()
        })
    }
}
/// Iterator returned by [`Transaction::lazy_iter`].
///
/// Each row id produced by `iter` is turned into a `crate::Lazy` value via
/// [`Transaction::lazy`] on `txn`.
pub struct LazyIter<'t, T: Table> {
    /// Transaction used to build each lazy row.
    txn: &'t Transaction<T::Schema>,
    /// Underlying query iterator yielding row ids.
    iter: crate::query::Iter<'t, TableRow<T>>,
}
impl<'t, T: Table> Iterator for LazyIter<'t, T> {
    type Item = crate::Lazy<'t, T>;

    /// Pulls the next row id from the query and wraps it as a lazy row.
    fn next(&mut self) -> Option<Self::Item> {
        let row = self.iter.next()?;
        Some(self.txn.lazy(row))
    }
}
impl<S: 'static> Transaction<S> {
    /// Inserts `val` into its table and returns the new row.
    ///
    /// # Errors
    /// Returns `T::Conflict` (built via `FromConflict`) when SQLite reports a
    /// constraint violation.
    pub fn insert<T: Table<Schema = S>>(&mut self, val: T) -> Result<TableRow<T>, T::Conflict> {
        let target = lower::JoinableTable::Table(T::NAME);
        try_insert_private(target, None, val)
    }

    /// Like `insert`, for tables whose inserts cannot conflict.
    pub fn insert_ok<T: Table<Schema = S, Conflict = Infallible>>(
        &mut self,
        val: T,
    ) -> TableRow<T> {
        let Ok(inserted) = self.insert(val);
        inserted
    }

    /// Inserts `val`; on conflict, returns the row it conflicts with instead.
    pub fn find_or_insert<T: Table<Schema = S, Conflict = TableRow<T>>>(
        &mut self,
        val: T,
    ) -> TableRow<T> {
        match self.insert(val) {
            Ok(row) | Err(row) => row,
        }
    }

    /// Overwrites the columns of `row` with the values in `val`.
    ///
    /// Emits `UPDATE <table> SET <col> = ?, ... WHERE <table>.<id> = ?`.
    ///
    /// # Errors
    /// Returns `T::Conflict` when SQLite reports a constraint violation.
    ///
    /// # Panics
    /// Panics if the statement changes a number of rows other than one, or fails
    /// for any reason other than a constraint violation.
    pub(crate) fn update<T: Table<Schema = S>>(
        &mut self,
        row: TableRow<T>,
        val: T::Mutable,
    ) -> Result<(), T::Conflict> {
        let insert_val = T::mutable_into_insert(val);
        let mut columns = Reader::default();
        T::read(&insert_val, &mut columns);

        let mut statement = emit::Stmt::default();
        statement.write("UPDATE ");
        lower::JoinableTable::Table(T::NAME).emit(&mut statement);
        statement.write(" SET ");
        let mut assignments = ListWriter::new(&mut statement, ", ");
        for (name, value) in &columns.builder {
            assignments
                .item()
                .write(format_args!("{} = ", Alias(name)))
                .write_param(&value);
        }
        // Presumably written only when no column was assigned, so `SET` is never
        // empty (`<id> = <table>.<id>` is a no-op assignment) — confirm.
        assignments.default(format_args!("{1} = {0}.{1}", Alias(T::NAME), Alias(T::ID)));
        statement.write(format_args!(
            " WHERE {}.{} = ",
            Alias(T::NAME),
            Alias(T::ID)
        ));
        statement.write_param(&OrdRc(Rc::new(row.inner.idx.into())));

        let outcome = TXN.with_borrow(|slot| {
            let txn = slot.as_ref().unwrap().get();
            let mut prepared = txn.prepare_cached(&statement.sql).unwrap();
            prepared.execute(rusqlite::params_from_iter(statement.params))
        });
        match outcome {
            Ok(1) => Ok(()),
            Ok(n) => panic!("unexpected number of updates: {n}"),
            Err(rusqlite::Error::SqliteFailure(kind, Some(msg)))
                if kind.code == ErrorCode::ConstraintViolation =>
            {
                let conflict = TXN.with_borrow(|slot| {
                    let txn = slot.as_ref().unwrap().get();
                    <T::Conflict as FromConflict>::from_conflict(
                        txn,
                        lower::JoinableTable::Table(T::NAME),
                        columns.builder,
                        msg,
                    )
                });
                Err(conflict)
            }
            Err(err) => panic!("{err:?}"),
        }
    }

    /// Converts this handle into a `TransactionWeak`, which allows deleting rows.
    pub fn downgrade(&'static mut self) -> &'static mut TransactionWeak<S> {
        let weak = Box::new(TransactionWeak { inner: PhantomData });
        Box::leak(weak)
    }
}
/// A transaction handle obtained from [`Transaction::downgrade`] that can delete rows.
///
/// It stores only `PhantomData<Transaction<S>>`, so it inherits `Transaction`'s
/// auto traits (neither `Send` nor `Sync`).
pub struct TransactionWeak<S> {
    inner: PhantomData<Transaction<S>>,
}
impl<S: Schema> TransactionWeak<S> {
    /// Deletes the row `val` from its table.
    ///
    /// First, every column in the schema whose foreign key points at `T`'s id
    /// column is probed for `val`. If any of them still contains it, nothing is
    /// deleted.
    ///
    /// Returns `Ok(true)` if the row was deleted and `Ok(false)` if no row matched.
    ///
    /// # Errors
    /// Returns `T::get_referer_unchecked()` when the row is still referenced.
    ///
    /// # Panics
    /// Panics on any SQLite error, or if more than one row would be deleted.
    pub fn delete<T: Table<Schema = S>>(&mut self, val: TableRow<T>) -> Result<bool, T::Referer> {
        let schema = crate::schema::from_macro::Schema::new::<S>();

        // One `SELECT ? IN (SELECT tbl.col FROM tbl)` probe per referencing column.
        let mut reference_checks = Vec::new();
        for (&table_name, table) in &schema.tables {
            for (col_name, col) in table.columns.iter() {
                let points_at_target = col
                    .def
                    .fk
                    .as_ref()
                    .is_some_and(|(t, c)| t == T::NAME && c == T::ID);
                if !points_at_target {
                    continue;
                }
                let mut probe = emit::Stmt::default();
                probe.write("SELECT ");
                probe.write_param(&OrdRc(Rc::new(val.inner.idx.into())));
                probe.write(format_args!(
                    " IN (SELECT {0}.{1} FROM {0})",
                    Alias(table_name),
                    Alias(col_name)
                ));
                reference_checks.push(probe);
            }
        }

        let mut removal = emit::Stmt::default();
        removal.write(format_args!(
            "DELETE FROM {0} WHERE {0}.{1} = ",
            Alias(T::NAME),
            Alias(T::ID)
        ));
        removal.write_param(&OrdRc::new(val.inner.idx));

        TXN.with_borrow(|slot| {
            let txn = slot.as_ref().unwrap().get();
            for probe in reference_checks {
                let mut prepared = txn.prepare_cached(&probe.sql).unwrap();
                let referenced: bool =
                    match prepared.query_one(rusqlite::params_from_iter(probe.params), |r| r.get(0)) {
                        Ok(flag) => flag,
                        Err(err) => panic!("{err:?}"),
                    };
                if referenced {
                    return Err(T::get_referer_unchecked());
                }
            }
            let mut prepared = txn.prepare_cached(&removal.sql).unwrap();
            match prepared.execute(rusqlite::params_from_iter(removal.params)) {
                Ok(0) => Ok(false),
                Ok(1) => Ok(true),
                Ok(n) => panic!("unexpected number of deletes {n}"),
                Err(err) => panic!("{err:?}"),
            }
        })
    }

    /// Like `delete`, for tables that can never be referenced.
    pub fn delete_ok<T: Table<Referer = Infallible, Schema = S>>(
        &mut self,
        val: TableRow<T>,
    ) -> bool {
        let Ok(deleted) = self.delete(val);
        deleted
    }

    /// Gives `f` direct access to this thread's rusqlite transaction.
    pub fn rusqlite_transaction<R>(&mut self, f: impl FnOnce(&rusqlite::Transaction) -> R) -> R {
        TXN.with_borrow(|slot| {
            let txn = slot.as_ref().unwrap().get();
            f(txn)
        })
    }
}
/// Builds and runs `INSERT INTO <table> ... RETURNING <id>` for `val`.
///
/// When `idx` is `Some`, that value is written into the id column, and the id
/// SQLite returns is asserted to equal it.
///
/// # Errors
/// Returns `T::Conflict` (via `FromConflict`) when SQLite reports a constraint
/// violation.
///
/// # Panics
/// Panics on any other SQLite error, or if the returned id differs from `idx`.
pub fn try_insert_private<T: Table>(
    table: lower::JoinableTable,
    idx: Option<i64>,
    val: T,
) -> Result<TableRow<T>, T::Conflict> {
    let mut columns = Reader::default();
    T::read(&val, &mut columns);
    if let Some(forced_id) = idx {
        columns.col::<i64>(T::ID, forced_id);
    }

    let mut statement = emit::Stmt::default();
    statement.write("INSERT INTO ");
    table.emit(&mut statement);
    if !columns.builder.is_empty() {
        // Cloned because `columns.builder` is still needed to describe a conflict.
        let (names, values): (Vec<_>, Vec<_>) = columns.builder.clone().into_iter().collect();
        statement.write(" (");
        let mut name_list = ListWriter::new(&mut statement, ", ");
        for name in names {
            name_list.item().write(Alias(name));
        }
        statement.write(") VALUES (");
        let mut value_list = ListWriter::new(&mut statement, ", ");
        for value in values {
            value_list.item().write_param(&value);
        }
        statement.write(")");
    } else {
        statement.write(" DEFAULT VALUES");
    }
    statement.write(" RETURNING ").write(T::ID);

    let fetched = TXN.with_borrow(|slot| {
        let txn = slot.as_ref().unwrap().get();
        track_stmt(txn, &statement.sql, &statement.params);
        let mut prepared = txn.prepare_cached(&statement.sql).unwrap();
        let mut ids = prepared
            .query_map(rusqlite::params_from_iter(statement.params), |row| {
                Ok(TableRow::<T>::from_sql(row.get_ref(T::ID)?)?)
            })
            .unwrap();
        ids.next().unwrap()
    });
    match fetched {
        Err(rusqlite::Error::SqliteFailure(kind, Some(msg)))
            if kind.code == ErrorCode::ConstraintViolation =>
        {
            let conflict = TXN.with_borrow(|slot| {
                let txn = slot.as_ref().unwrap().get();
                <T::Conflict as FromConflict>::from_conflict(txn, table, columns.builder, msg)
            });
            Err(conflict)
        }
        Err(err) => panic!("{err:?}"),
        Ok(id) => {
            if let Some(forced_id) = idx {
                assert_eq!(forced_id, id.inner.idx);
            }
            Ok(id)
        }
    }
}