use crate::drivers::Driver;
use crate::pool::{
ConnectionAdapter, ConnectionPool, ConnectionPoolConfig, ConnectionPoolError, PooledConnection,
};
use crate::{ReadTransaction, Transaction};
use flume::Sender;
use std::marker::PhantomData;
use std::sync::Arc;
use tracing::error;
/// Error type of the async connection API.
///
/// `E` is the error type of the work that was executed: either the driver's
/// error type or a caller-chosen type convertible from it.
#[derive(Debug, thiserror::Error)]
pub enum AsyncConnectionError<E> {
    /// The executed work failed. Display and source are forwarded to the
    /// wrapped error (`#[error(transparent)]`).
    #[error(transparent)]
    Connection(#[from] E),
    /// A command or its reply could not be exchanged with the background
    /// worker (one side of the channel was gone).
    #[error("Communication with worker failed")]
    Worker,
}
impl<E> AsyncConnectionError<E> {
    /// Changes the error type parameter from `E` to `T`.
    ///
    /// A wrapped `Connection` error is converted with `T::from`. `Worker`
    /// passes through unchanged. The transaction helpers use this to widen
    /// `AsyncConnectionError<T::Error>` into the caller's error type.
    fn into<T: From<E>>(self) -> AsyncConnectionError<T> {
        match self {
            Self::Worker => AsyncConnectionError::Worker,
            Self::Connection(inner) => AsyncConnectionError::Connection(T::from(inner)),
        }
    }
}
/// Abstraction over an async runtime's "run blocking work elsewhere" facility.
///
/// It lets this module hand blocking calls off the async executor without
/// depending on a specific runtime.
pub trait AsyncRuntime: Send + Sync + 'static {
    /// Error reported when a spawned blocking task cannot be joined. Which
    /// conditions produce it is runtime-specific.
    type JoinError: std::error::Error + Send + Sync + 'static;

    /// Future that resolves to the blocking closure's return value, or to a
    /// [`Self::JoinError`].
    type JoinHandle<T: Send + 'static>: Future<Output = Result<T, Self::JoinError>> + Send + 'static;

    /// Schedules `closure` where blocking is permitted and returns a handle
    /// that resolves to its result.
    fn spawn_blocking<F, T>(closure: F) -> Self::JoinHandle<T>
    where
        T: Send + 'static,
        F: FnOnce() -> T + Send + 'static;
}
/// Connection pool whose connections are driven through [`AsyncConnectionAdapter`].
pub type AsyncConnectionPool<T> = ConnectionPool<T, AsyncConnectionAdapter<T>>;
/// Connection checked out of an [`AsyncConnectionPool`].
pub type AsyncPooledConnection<T> = PooledConnection<T, AsyncConnectionAdapter<T>>;
impl<T, A> ConnectionPool<T, A>
where
    T: Driver,
    A: ConnectionAdapter<T> + Send + 'static,
    <T as Driver>::Connection: Send + 'static,
    <T as Driver>::Error: Send + 'static,
    <T as Driver>::ConnectionError: Send + 'static,
{
    /// Builds a pool by running `ConnectionPool::new` on `R`'s blocking
    /// facility, so the (presumably blocking) construction stays off the async
    /// executor.
    ///
    /// # Errors
    ///
    /// Returns whatever `ConnectionPool::new` returns. If the blocking task
    /// cannot be joined, returns [`ConnectionPoolError::Other`] wrapping the
    /// runtime's join error.
    pub async fn new_async<R: AsyncRuntime>(
        config: ConnectionPoolConfig,
    ) -> Result<Arc<Self>, ConnectionPoolError<T::ConnectionError>> {
        let joined = R::spawn_blocking(move || Self::new(config)).await;
        match joined {
            Ok(built) => built,
            Err(join_err) => Err(ConnectionPoolError::Other(Box::new(join_err))),
        }
    }

    /// Checks out a connection by running `connection()` on `R`'s blocking
    /// facility, which keeps a (possibly waiting) checkout off the executor.
    ///
    /// # Errors
    ///
    /// Returns whatever `connection()` returns. If the blocking task cannot be
    /// joined, returns [`ConnectionPoolError::Other`] wrapping the runtime's
    /// join error.
    // The nested generic return type trips clippy's complexity heuristic; an
    // alias would only hide which adapter is returned.
    #[allow(clippy::type_complexity)]
    pub async fn connection_async<R: AsyncRuntime>(
        self: &Arc<Self>,
    ) -> Result<PooledConnection<T, A>, ConnectionPoolError<T::Error>> {
        let pool = Arc::clone(self);
        match R::spawn_blocking(move || pool.connection()).await {
            Ok(checkout) => checkout,
            Err(join_err) => Err(ConnectionPoolError::Other(Box::new(join_err))),
        }
    }
}
/// Connection adapter that gives each driver connection its own worker
/// thread and exposes async methods that queue work onto that thread.
pub struct AsyncConnectionAdapter<T: Driver> {
    /// Command queue consumed by the connection's worker thread.
    sender: Sender<AsyncCommand<T::Connection>>,
}
impl<T> AsyncConnectionAdapter<T>
where
T: Driver,
<T as Driver>::Connection: Send + 'static,
<T as Driver>::Error: Send + 'static,
{
fn new(mut connection: T::Connection) -> Result<Self, std::io::Error> {
let (sender, receiver) = flume::bounded(1);
std::thread::Builder::new()
.name("sqlite-wrc-async".to_owned())
.spawn(move || {
while let Ok(cmd) = receiver.recv() {
match cmd {
AsyncCommand::Execute(execute) => {
execute(&mut connection);
}
}
}
})?;
Ok(AsyncConnectionAdapter { sender })
}
async fn read<F, R, E>(&self, closure: F) -> Result<R, AsyncConnectionError<E>>
where
F: FnOnce(&mut T::Connection) -> Result<R, E> + Send + 'static,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let (sx, rx) = oneshot::channel();
self.sender
.send_async(AsyncCommand::Execute(Box::new(
move |conn: &mut T::Connection| {
let r = closure(conn);
let _ = sx.send(r);
},
)))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx.await.map_err(|_| AsyncConnectionError::Worker)??)
}
async fn read_transaction_closure<F, R, E>(
&self,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: FnOnce(&mut ReadTransaction<'_, T>) -> Result<R, E> + Send + 'static,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let (sx, rx) = oneshot::channel();
self.sender
.send_async(AsyncCommand::Execute(Box::new(
move |conn: &mut T::Connection| {
let r = ReadTransaction::scoped(conn, closure);
let _ = sx.send(r);
},
)))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx.await.map_err(|_| AsyncConnectionError::Worker)??)
}
async fn transaction_closure<F, R, E>(
&mut self,
pool: Arc<ConnectionPool<T, Self>>,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: FnOnce(&mut Transaction<'_, T>) -> Result<R, E> + Send + 'static,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let (sx, rx) = oneshot::channel();
self.sender
.send_async(AsyncCommand::Execute(Box::new(
move |_: &mut T::Connection| {
let r = pool.transaction_closure(closure);
let _ = sx.send(r);
},
)))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx.await.map_err(|_| AsyncConnectionError::Worker)??)
}
async fn transaction_closure_async<F, R, E>(
&mut self,
pool: Arc<ConnectionPool<T, Self>>,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: AsyncFnOnce(&mut AsyncTransaction<T>) -> Result<R, AsyncConnectionError<E>>,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let mut tx = self
.transaction(pool)
.await
.map_err(AsyncConnectionError::into)?;
let r = closure(&mut tx).await;
let tx_result = if r.is_ok() {
tx.commit().await
} else {
tx.rollback().await
};
tx_result.map_err(AsyncConnectionError::into)?;
r
}
async fn read_transaction_closure_async<F, R, E>(
&mut self,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: AsyncFnOnce(&mut AsyncReadTransaction<T>) -> Result<R, AsyncConnectionError<E>>,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let mut tx = self
.read_transaction()
.await
.map_err(AsyncConnectionError::into)?;
let r = closure(&mut tx).await;
tx.commit().await.map_err(AsyncConnectionError::into)?;
r
}
async fn transaction(
&mut self,
pool: Arc<ConnectionPool<T, Self>>,
) -> Result<AsyncTransaction<'_, T>, AsyncConnectionError<T::Error>> {
let (sx, rx) = oneshot::channel();
self.sender
.send_async(AsyncCommand::Execute(Box::new(
move |_: &mut T::Connection| {
let (async_rx, mut tx) = match pool.transaction() {
Ok(tx) => {
let (async_tx, async_rx) = flume::bounded(1);
let _ = sx.send(Ok(async_tx));
(async_rx, tx)
}
Err(e) => {
let _ = sx.send(Err(e));
return;
}
};
while let Ok(cmd) = async_rx.recv() {
match cmd {
AsyncTxCommand::Op(op) => op(&mut tx),
AsyncTxCommand::Commit(sx) => {
let _ = sx.send(tx.commit());
return;
}
AsyncTxCommand::Rollback(sx) => {
let _ = sx.send(tx.rollback());
return;
}
AsyncTxCommand::ReadOp(_) => {
tracing::warn!("Received invalid command");
}
}
}
},
)))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx
.await
.map_err(|_| AsyncConnectionError::Worker)?
.map(AsyncTransaction::new)?)
}
async fn read_transaction(
&mut self,
) -> Result<AsyncReadTransaction<'_, T>, AsyncConnectionError<T::Error>> {
let (sx, rx) = oneshot::channel();
self.sender
.send_async(AsyncCommand::Execute(Box::new(
move |conn: &mut T::Connection| {
let (async_rx, mut tx): (_, ReadTransaction<T>) =
match ReadTransaction::new(conn) {
Ok(tx) => {
let (async_tx, async_rx) = flume::bounded(1);
let _ = sx.send(Ok(async_tx));
(async_rx, tx)
}
Err(e) => {
let _ = sx.send(Err(e));
return;
}
};
while let Ok(cmd) = async_rx.recv() {
match cmd {
AsyncTxCommand::ReadOp(op) => op(&mut tx),
AsyncTxCommand::Commit(sx) => {
let _ = sx.send(tx.commit());
return;
}
AsyncTxCommand::Rollback(sx) => {
let _ = sx.send(tx.rollback());
return;
}
AsyncTxCommand::Op(_) => {
tracing::warn!("Received invalid command");
}
}
}
},
)))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx
.await
.map_err(|_| AsyncConnectionError::Worker)?
.map(AsyncReadTransaction::new)?)
}
}
impl<T> ConnectionAdapter<T> for AsyncConnectionAdapter<T>
where
    T: Driver,
    <T as Driver>::Connection: Send + 'static,
    <T as Driver>::Error: Send + 'static,
{
    /// Hands `connection` to a newly spawned worker thread and returns the
    /// adapter that talks to it.
    ///
    /// # Panics
    ///
    /// Panics if the worker thread cannot be spawned. This trait method
    /// returns `Self`, so it has no way to report the error.
    fn from_driver_connection(connection: T::Connection) -> Self {
        let spawned = Self::new(connection);
        spawned.expect("Failed to create named thread")
    }
}
impl<T> PooledConnection<T, AsyncConnectionAdapter<T>>
where
T: Driver,
<T as Driver>::Connection: Send + 'static,
<T as Driver>::Error: Send + 'static,
{
pub async fn read<F, R, E>(&mut self, closure: F) -> Result<R, AsyncConnectionError<E>>
where
F: FnOnce(&mut T::Connection) -> Result<R, E> + Send + 'static,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
self.connection_mut().read(closure).await
}
pub async fn read_transaction_closure<F, R, E>(
&mut self,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: FnOnce(&mut ReadTransaction<'_, T>) -> Result<R, E> + Send + 'static,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
self.connection_mut()
.read_transaction_closure(closure)
.await
}
pub async fn transaction_closure<F, R, E>(
&mut self,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: FnOnce(&mut Transaction<'_, T>) -> Result<R, E> + Send + 'static,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let pool = self.pool.clone();
self.connection_mut()
.transaction_closure(pool, closure)
.await
}
pub async fn transaction_closure_async<F, R, E>(
&mut self,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: AsyncFnOnce(&mut AsyncTransaction<T>) -> Result<R, AsyncConnectionError<E>>,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let pool = self.pool.clone();
self.connection_mut()
.transaction_closure_async(pool, closure)
.await
}
pub async fn read_transaction_closure_async<F, R, E>(
&mut self,
closure: F,
) -> Result<R, AsyncConnectionError<E>>
where
F: AsyncFnOnce(&mut AsyncReadTransaction<T>) -> Result<R, AsyncConnectionError<E>>,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
self.connection_mut()
.read_transaction_closure_async(closure)
.await
}
pub async fn transaction(
&mut self,
) -> Result<AsyncTransaction<'_, T>, AsyncConnectionError<T::Error>> {
let pool = self.pool.clone();
self.connection_mut().transaction(pool).await
}
pub async fn read_transaction(
&mut self,
) -> Result<AsyncReadTransaction<'_, T>, AsyncConnectionError<T::Error>> {
self.connection_mut().read_transaction().await
}
}
/// Commands understood by a worker thread that is parked inside a
/// transaction (driven by `AsyncTransaction` or `AsyncReadTransaction`).
enum AsyncTxCommand<T: Driver> {
    /// Run a closure against the open write [`Transaction`].
    // Boxed higher-ranked closure types trip clippy's complexity heuristic;
    // an alias would not make them clearer here.
    #[allow(clippy::type_complexity)]
    Op(Box<dyn FnOnce(&mut Transaction<'_, T>) + Send + 'static>),
    /// Run a closure against the open [`ReadTransaction`].
    // Same reason as `Op`.
    #[allow(clippy::type_complexity)]
    ReadOp(Box<dyn FnOnce(&mut ReadTransaction<'_, T>) + Send + 'static>),
    /// Commit, reply with the driver's result, and leave the transaction loop.
    Commit(oneshot::Sender<Result<(), T::Error>>),
    /// Roll back, reply with the driver's result, and leave the transaction loop.
    Rollback(oneshot::Sender<Result<(), T::Error>>),
}
/// Handle to a write transaction held open on a connection's worker thread.
///
/// Operations are forwarded to the worker over the wrapped channel. The
/// lifetime `'c` ties the handle to the mutable borrow of the connection it
/// was opened on. Dropping the handle without `commit`/`rollback` disconnects
/// the channel; the worker then leaves its transaction loop and drops the
/// transaction.
pub struct AsyncTransaction<'c, T: Driver>(
    flume::Sender<AsyncTxCommand<T>>,
    PhantomData<&'c mut T>,
);
impl<T: Driver> AsyncTransaction<'_, T> {
fn new(sender: flume::Sender<AsyncTxCommand<T>>) -> Self {
Self(sender, PhantomData)
}
pub async fn run<F, R, E>(&mut self, closure: F) -> Result<R, AsyncConnectionError<E>>
where
F: FnOnce(&mut Transaction<'_, T>) -> Result<R, E> + Send + 'static,
R: Send + 'static,
E: From<T::Error> + Send + 'static,
{
let (sx, rx) = oneshot::channel();
self.0
.send_async(AsyncTxCommand::Op(Box::new(
move |tx: &mut Transaction<'_, T>| {
let r = closure(tx);
let _ = sx.send(r);
},
)))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx.await.map_err(|_| AsyncConnectionError::Worker)??)
}
pub async fn commit(self) -> Result<(), AsyncConnectionError<T::Error>> {
let (sx, rx) = oneshot::channel();
self.0
.send_async(AsyncTxCommand::Commit(sx))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx.await.map_err(|_| AsyncConnectionError::Worker)??)
}
pub async fn rollback(self) -> Result<(), AsyncConnectionError<T::Error>> {
let (sx, rx) = oneshot::channel();
self.0
.send_async(AsyncTxCommand::Rollback(sx))
.await
.map_err(|_| AsyncConnectionError::Worker)?;
Ok(rx.await.map_err(|_| AsyncConnectionError::Worker)??)
}
}
/// Handle to a read transaction held open on a connection's worker thread.
///
/// Operations are forwarded to the worker over the wrapped channel. The
/// lifetime `'c` ties the handle to the mutable borrow of the connection it
/// was opened on.
pub struct AsyncReadTransaction<'c, T: Driver>(
    flume::Sender<AsyncTxCommand<T>>,
    PhantomData<&'c mut T>,
);
impl<T: Driver> AsyncReadTransaction<'_, T> {
    /// Wraps the command channel of a worker parked inside a read transaction.
    fn new(sender: flume::Sender<AsyncTxCommand<T>>) -> Self {
        Self(sender, PhantomData)
    }

    /// Runs `closure` against the open [`ReadTransaction`] on the worker
    /// thread and returns its result.
    ///
    /// # Errors
    ///
    /// [`AsyncConnectionError::Worker`] if the worker is unreachable or drops
    /// the reply. [`AsyncConnectionError::Connection`] with the closure's
    /// error.
    pub async fn run<F, R, E>(&mut self, closure: F) -> Result<R, AsyncConnectionError<E>>
    where
        F: FnOnce(&mut ReadTransaction<'_, T>) -> Result<R, E> + Send + 'static,
        R: Send + 'static,
        E: From<T::Error> + Send + 'static,
    {
        let (reply_tx, reply_rx) = oneshot::channel();
        let command: AsyncTxCommand<T> = AsyncTxCommand::ReadOp(Box::new(
            move |tx: &mut ReadTransaction<'_, T>| {
                let _ = reply_tx.send(closure(tx));
            },
        ));
        if self.0.send_async(command).await.is_err() {
            return Err(AsyncConnectionError::Worker);
        }
        match reply_rx.await {
            Ok(outcome) => outcome.map_err(AsyncConnectionError::Connection),
            Err(_) => Err(AsyncConnectionError::Worker),
        }
    }

    /// Commits (ends) the read transaction and reports the driver's result.
    ///
    /// # Errors
    ///
    /// `Worker` if the worker is unreachable or drops the reply. `Connection`
    /// with the driver's commit error.
    pub async fn commit(self) -> Result<(), AsyncConnectionError<T::Error>> {
        let (reply_tx, reply_rx) = oneshot::channel();
        if self.0.send_async(AsyncTxCommand::Commit(reply_tx)).await.is_err() {
            return Err(AsyncConnectionError::Worker);
        }
        match reply_rx.await {
            Ok(outcome) => outcome.map_err(AsyncConnectionError::Connection),
            Err(_) => Err(AsyncConnectionError::Worker),
        }
    }
}
/// Work item executed by a connection's worker thread; `T` is the value the
/// worker owns (the driver connection, for `AsyncConnectionAdapter`).
enum AsyncCommand<T> {
    /// Run the boxed closure with exclusive access to the worker's `T`.
    Execute(Box<dyn FnOnce(&mut T) + Send + 'static>),
}