use crate::DatabaseIdentifier;
use sqlx::{Pool, Sqlite, SqliteConnection, Transaction};
/// Handle to a SQLite connection pool, paired with an identifier of type `I`.
///
/// Transactions started from this handle are [`TxRead`]s. These expose no
/// `commit`; they can only be rolled back, e.g. via [`TxRead::close`].
/// Use [`DbWrite`] when changes must be committed.
#[derive(Clone, Debug)]
pub struct DbRead<I: DatabaseIdentifier> {
    /// Pool that [`DbRead::begin`] starts transactions on.
    pub(crate) pool: Pool<Sqlite>,
    /// Identifier cloned into every transaction started from this handle.
    pub(crate) identifier: I,
}
impl<I: DatabaseIdentifier> DbRead<I> {
pub(crate) fn new(pool: Pool<Sqlite>, identifier: I) -> Self {
Self { pool, identifier }
}
pub fn identifier(&self) -> &I {
&self.identifier
}
pub fn pool(&self) -> &Pool<Sqlite> {
&self.pool
}
pub async fn begin(&self) -> sqlx::Result<TxRead<I>> {
let tx = self.pool.begin().await?;
Ok(TxRead::new(tx, self.identifier.clone()))
}
}
/// Write-capable handle that wraps a [`DbRead`].
///
/// Its `begin` yields a [`TxWrite`], which (unlike [`TxRead`]) can be committed.
/// A `DbWrite` converts into, or borrows as, a [`DbRead`] via `From` and `AsRef`.
#[derive(Clone, Debug)]
pub struct DbWrite<I: DatabaseIdentifier>(DbRead<I>);
impl<I: DatabaseIdentifier> DbWrite<I> {
pub(crate) fn new(pool: Pool<Sqlite>, identifier: I) -> Self {
Self(DbRead::new(pool, identifier))
}
pub fn identifier(&self) -> &I {
self.0.identifier()
}
pub fn pool(&self) -> &Pool<Sqlite> {
self.0.pool()
}
pub async fn begin(&self) -> sqlx::Result<TxWrite<I>> {
let tx = self.pool().begin().await?;
Ok(TxWrite(TxRead::new(tx, self.0.identifier.clone())))
}
}
/// A transaction that can run queries and be rolled back, but not committed.
///
/// See [`TxRead::close`] for the explicit rollback.
///
/// NOTE(review): if dropped without `close`, cleanup is left to the drop
/// behavior of sqlx's `Transaction`. sqlx documents that behavior as a
/// rollback; confirm for the sqlx version in use.
pub struct TxRead<I: DatabaseIdentifier> {
    /// The underlying sqlx transaction.
    ///
    /// Becomes `None` once it has been taken by `close` (or by
    /// `TxWrite::commit` / `TxWrite::rollback`).
    tx: Option<Transaction<'static, Sqlite>>,
    /// Identifier supplied at construction. When created via `begin`, this is
    /// a clone of the originating handle's identifier.
    identifier: I,
}
impl<I: DatabaseIdentifier> TxRead<I> {
pub(crate) fn new(tx: Transaction<'static, Sqlite>, identifier: I) -> Self {
Self {
tx: Some(tx),
identifier,
}
}
pub fn identifier(&self) -> &I {
&self.identifier
}
pub async fn close(mut self) -> sqlx::Result<()> {
self.tx
.take()
.expect("transaction already consumed")
.rollback()
.await
}
pub(crate) fn conn_mut(&mut self) -> &mut SqliteConnection {
self.tx.as_mut().expect("transaction already consumed")
}
pub(crate) fn tx_mut(&mut self) -> &mut Transaction<'static, Sqlite> {
self.tx.as_mut().expect("transaction already consumed")
}
}
/// A transaction, created by [`DbWrite::begin`], that can be committed or
/// rolled back.
///
/// It wraps a [`TxRead`] and converts into, or borrows as, one via `From`,
/// `AsRef` and `AsMut`.
pub struct TxWrite<I: DatabaseIdentifier>(TxRead<I>);
impl<I: DatabaseIdentifier> TxWrite<I> {
    /// Returns the identifier this transaction was created with.
    pub fn identifier(&self) -> &I {
        self.0.identifier()
    }

    /// Commits the underlying transaction.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying `commit`.
    ///
    /// # Panics
    ///
    /// Panics if the inner transaction has already been taken.
    pub async fn commit(mut self) -> sqlx::Result<()> {
        let tx = self.0.tx.take().expect("transaction already consumed");
        tx.commit().await
    }

    /// Rolls the underlying transaction back.
    ///
    /// This delegates to [`TxRead::close`], which performs exactly this
    /// rollback, so the take-and-rollback logic lives in one place.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying `rollback`.
    ///
    /// # Panics
    ///
    /// Panics if the inner transaction has already been taken.
    pub async fn rollback(self) -> sqlx::Result<()> {
        self.0.close().await
    }

    /// Borrows the transaction's connection so queries can run on it.
    ///
    /// # Panics
    ///
    /// Panics if the inner transaction has already been taken.
    pub(crate) fn conn_mut(&mut self) -> &mut SqliteConnection {
        self.0.conn_mut()
    }

    /// Borrows the underlying sqlx transaction.
    ///
    /// # Panics
    ///
    /// Panics if the inner transaction has already been taken.
    pub(crate) fn tx_mut(&mut self) -> &mut Transaction<'static, Sqlite> {
        self.0.tx_mut()
    }
}
impl<I: DatabaseIdentifier> From<DbWrite<I>> for DbRead<I> {
fn from(write: DbWrite<I>) -> Self {
write.0
}
}
/// Borrows the read handle wrapped by a write handle.
impl<I: DatabaseIdentifier> AsRef<DbRead<I>> for DbWrite<I> {
    fn as_ref(&self) -> &DbRead<I> {
        let Self(read) = self;
        read
    }
}
impl<I: DatabaseIdentifier> From<TxWrite<I>> for TxRead<I> {
fn from(write: TxWrite<I>) -> Self {
write.0
}
}
/// Borrows the read transaction wrapped by a write transaction.
impl<I: DatabaseIdentifier> AsRef<TxRead<I>> for TxWrite<I> {
    fn as_ref(&self) -> &TxRead<I> {
        let Self(read) = self;
        read
    }
}
/// Mutably borrows the read transaction wrapped by a write transaction, so
/// read-only helpers can run their queries inside it.
impl<I: DatabaseIdentifier> AsMut<TxRead<I>> for TxWrite<I> {
    fn as_mut(&mut self) -> &mut TxRead<I> {
        let Self(read) = self;
        read
    }
}