use crate::{Connection, ConnectionEnv, FrankenError, Row, SqliteValue};
use asupersync::channel::oneshot;
use asupersync::cx::Cx as NativeCx;
use asupersync::runtime::{BlockingTaskHandle, Runtime, RuntimeBuilder, RuntimeHandle};
use fsqlite_types::cx::Cx;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use std::time::Duration;
type Responder<T> = std::sync::mpsc::SyncSender<Result<T, FrankenError>>;
const WORKER_POLL_INTERVAL: Duration = Duration::from_millis(10);
/// A request sent from [`AsyncConnection`] to the worker task that owns the `Connection`.
///
/// Every variant except `Shutdown` carries a [`Responder`] on which the worker sends the outcome
/// of the matching `Connection` call.
enum Command {
    /// Run `Connection::query(sql)`.
    Query {
        sql: String,
        tx: Responder<Vec<Row>>,
    },
    /// Run `Connection::query_with_params(sql, params)`.
    QueryWithParams {
        sql: String,
        params: Vec<SqliteValue>,
        tx: Responder<Vec<Row>>,
    },
    /// Run `Connection::query_row(sql)`.
    QueryRow {
        sql: String,
        tx: Responder<Row>,
    },
    /// Run `Connection::query_row_with_params(sql, params)`.
    QueryRowWithParams {
        sql: String,
        params: Vec<SqliteValue>,
        tx: Responder<Row>,
    },
    /// Run `Connection::execute(sql)`; replies with the count it returns.
    Execute {
        sql: String,
        tx: Responder<usize>,
    },
    /// Run `Connection::execute_with_params(sql, params)`; replies with the count it returns.
    ExecuteWithParams {
        sql: String,
        params: Vec<SqliteValue>,
        tx: Responder<usize>,
    },
    /// Run `Connection::execute_batch(sql)`.
    ExecuteBatch {
        sql: String,
        tx: Responder<()>,
    },
    /// Run `Connection::begin_transaction()`.
    BeginTransaction {
        tx: Responder<()>,
    },
    /// Run `Connection::commit_transaction()`.
    CommitTransaction {
        tx: Responder<()>,
    },
    /// Run `Connection::rollback_transaction()`.
    RollbackTransaction {
        tx: Responder<()>,
    },
    /// Run `Connection::close_in_place()`, reply, then stop the worker.
    Close {
        tx: Responder<()>,
    },
    /// Stop the worker without replying.
    Shutdown,
}
/// Error returned when the worker's open-result channel disconnects before any result arrives.
fn worker_open_err() -> FrankenError {
    let message = String::from("async worker task terminated during open");
    FrankenError::Internal(message)
}
/// Error returned when a reply can no longer arrive because the worker side has gone away.
fn worker_dead_err() -> FrankenError {
    let message = String::from("async worker task terminated unexpectedly");
    FrankenError::Internal(message)
}
/// Error returned when an async method runs without a current runtime, or the current runtime
/// has no blocking pool.
fn requires_runtime_err() -> FrankenError {
    let message = String::from(
        "AsyncConnection async methods require an asupersync runtime with a blocking pool",
    );
    FrankenError::Internal(message)
}
/// Error returned when `spawn_blocking` declines to start the worker task.
fn worker_spawn_err() -> FrankenError {
    let message = String::from("failed to spawn async worker task: runtime has no blocking pool");
    FrankenError::Internal(message)
}
/// Deliberately discards a failed oneshot forward of a worker reply.
///
/// NOTE(review): assumed to happen only when the awaiting side has already gone (e.g. its future
/// was dropped), in which case nobody is left to receive the result.
fn blocking_wait_send_err<T>(_unwanted: oneshot::SendError<Result<T, FrankenError>>) {}
/// Chooses the native asupersync context to wait under on behalf of `cx`.
///
/// Returns the native context attached to `cx` if there is one; otherwise builds a fresh one with
/// `NativeCx::for_request()`.
/// NOTE(review): in the fallback case, cancelling `cx` presumably does not cancel the fresh
/// context — confirm.
fn native_cx_for_local<Caps>(cx: &Cx<Caps>) -> NativeCx
where
    Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
{
    match cx.attached_native_cx() {
        Some(attached) => attached,
        None => NativeCx::for_request(),
    }
}
async fn recv_sync_response<
Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
T: Send + 'static,
>(
cx: &Cx<Caps>,
rx: mpsc::Receiver<T>,
) -> Result<T, FrankenError> {
let runtime = Runtime::current_handle().ok_or_else(requires_runtime_err)?;
let pool = runtime.blocking_handle().ok_or_else(requires_runtime_err)?;
let native_cx = native_cx_for_local(cx);
let waiter_cx = native_cx.clone();
let (result_tx, mut result_rx) = oneshot::channel::<Result<T, FrankenError>>();
pool.spawn(move || {
let result = rx.recv().map_err(|_| worker_dead_err());
let _ = result_tx
.send(&waiter_cx, result)
.map_err(blocking_wait_send_err);
});
match result_rx.recv(&native_cx).await {
Ok(result) => result,
Err(oneshot::RecvError::Cancelled) => Err(FrankenError::Interrupt),
Err(oneshot::RecvError::Closed | oneshot::RecvError::PolledAfterCompletion) => {
Err(worker_dead_err())
}
}
}
/// Body of the worker task that owns `conn`.
///
/// Takes commands from `rx` one at a time and answers each on its embedded responder. It waits
/// on the queue in `WORKER_POLL_INTERVAL` slices so that cancellation of `worker_cx` is noticed
/// even while idle.
///
/// Returns, dropping `conn`, when any of these happens:
/// * `worker_cx` is cancelled;
/// * every sender for `rx` has been dropped;
/// * a `Close` command has been answered;
/// * a `Shutdown` command is received.
fn worker_loop(mut conn: Connection, rx: mpsc::Receiver<Command>, worker_cx: NativeCx) {
    loop {
        if worker_cx.checkpoint().is_err() {
            return;
        }
        let cmd = match rx.recv_timeout(WORKER_POLL_INTERVAL) {
            Ok(cmd) => cmd,
            Err(mpsc::RecvTimeoutError::Timeout) => continue,
            Err(mpsc::RecvTimeoutError::Disconnected) => return,
        };
        // A failed reply send means the receiving side is already gone; the result is simply
        // dropped.
        match cmd {
            Command::Query { sql, tx } => {
                let _ = tx.send(conn.query(&sql));
            }
            Command::QueryWithParams { sql, params, tx } => {
                let _ = tx.send(conn.query_with_params(&sql, &params));
            }
            Command::QueryRow { sql, tx } => {
                let _ = tx.send(conn.query_row(&sql));
            }
            Command::QueryRowWithParams { sql, params, tx } => {
                let _ = tx.send(conn.query_row_with_params(&sql, &params));
            }
            Command::Execute { sql, tx } => {
                let _ = tx.send(conn.execute(&sql));
            }
            Command::ExecuteWithParams { sql, params, tx } => {
                let _ = tx.send(conn.execute_with_params(&sql, &params));
            }
            Command::ExecuteBatch { sql, tx } => {
                let _ = tx.send(conn.execute_batch(&sql));
            }
            Command::BeginTransaction { tx } => {
                let _ = tx.send(conn.begin_transaction());
            }
            Command::CommitTransaction { tx } => {
                let _ = tx.send(conn.commit_transaction());
            }
            Command::RollbackTransaction { tx } => {
                let _ = tx.send(conn.rollback_transaction());
            }
            Command::Close { tx } => {
                let _ = tx.send(conn.close_in_place());
                return;
            }
            Command::Shutdown => {
                return;
            }
        }
    }
}
/// Starts the worker on `runtime`'s blocking pool.
///
/// The job opens `path` with `env` and reports the outcome on `open_tx`. On success it then
/// serves `cmd_rx` via `worker_loop` until that returns.
///
/// # Errors
///
/// `worker_spawn_err()` when `spawn_blocking` does not start the job.
fn spawn_worker_task(
    runtime: &RuntimeHandle,
    worker_cx: NativeCx,
    path: String,
    env: ConnectionEnv,
    cmd_rx: mpsc::Receiver<Command>,
    open_tx: mpsc::SyncSender<Result<(), FrankenError>>,
) -> Result<BlockingTaskHandle, FrankenError> {
    let job = move || {
        let conn = match Connection::open_with_env(path, env) {
            Ok(conn) => conn,
            Err(error) => {
                let _ = open_tx.send(Err(error));
                return;
            }
        };
        let _ = open_tx.send(Ok(()));
        worker_loop(conn, cmd_rx, worker_cx);
    };
    runtime.spawn_blocking(job).ok_or_else(worker_spawn_err)
}
fn build_owned_runtime() -> Result<Runtime, FrankenError> {
RuntimeBuilder::current_thread()
.blocking_threads(1, 1)
.build()
.map_err(|error| {
FrankenError::Internal(format!("failed to build async-api runtime: {error}"))
})
}
/// Returns a runtime handle that can host the worker.
///
/// Prefers the current runtime when it has a blocking pool; in that case the first element is
/// `None`. Otherwise it builds a private runtime and returns it alongside its handle, so the
/// caller can keep it alive.
///
/// # Errors
///
/// Whatever `build_owned_runtime` reports.
fn current_or_owned_runtime() -> Result<(Option<Runtime>, RuntimeHandle), FrankenError> {
    let usable_current =
        Runtime::current_handle().filter(|handle| handle.blocking_handle().is_some());
    if let Some(handle) = usable_current {
        return Ok((None, handle));
    }
    let owned = build_owned_runtime()?;
    let owned_handle = owned.handle();
    Ok((Some(owned), owned_handle))
}
/// Blocks the calling thread until the worker reports whether opening succeeded.
///
/// # Errors
///
/// The worker's open error, or `worker_open_err()` if the channel disconnects without a report.
fn wait_for_worker_open(
    open_rx: mpsc::Receiver<Result<(), FrankenError>>,
) -> Result<(), FrankenError> {
    match open_rx.recv() {
        Ok(report) => report,
        Err(_) => Err(worker_open_err()),
    }
}
/// Waits, blocking the calling thread, for the worker task behind `task` to finish.
fn join_worker_task(task: BlockingTaskHandle) {
    task.wait();
}
/// Checkpoints `cx`, translating any checkpoint failure into `FrankenError::Interrupt`.
fn checkpoint_or_interrupt<Caps>(cx: &Cx<Caps>) -> Result<(), FrankenError>
where
    Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
{
    if cx.checkpoint().is_err() {
        return Err(FrankenError::Interrupt);
    }
    Ok(())
}
/// Maps a failed send on the command queue (its receiver is gone) to an `Internal` error.
fn send_err<T>(_unsent: mpsc::SendError<T>) -> FrankenError {
    let message = String::from("async worker task is no longer running");
    FrankenError::Internal(message)
}
/// Async front-end to a `Connection` that lives on a dedicated blocking worker task.
///
/// Each operation is sent to the worker as a command and its reply is awaited, so SQL work does
/// not run on the async executor.
pub struct AsyncConnection {
    /// Command queue into the worker; `None` once `close` (or `Drop`) has taken it.
    cmd_tx: Option<mpsc::SyncSender<Command>>,
    /// Handle of the worker task, joined by `close` and `Drop`.
    worker: Option<BlockingTaskHandle>,
    /// The worker's context, cancelled by `close` and `Drop`.
    worker_cx: Option<NativeCx>,
    /// Runtime built by `open_sync_with_env` when no suitable current runtime was available.
    /// NOTE(review): presumably held so the worker's blocking pool outlives the worker — confirm.
    owned_runtime: Option<Runtime>,
    /// Transaction flag maintained only by `begin_transaction`, `commit_transaction` and
    /// `rollback_transaction`.
    in_txn: Arc<AtomicBool>,
}
impl AsyncConnection {
    /// Opens `path` with `ConnectionEnv::default()` on a new worker task.
    ///
    /// # Errors
    ///
    /// Same as [`AsyncConnection::open_with_env`].
    pub async fn open<Caps>(cx: &Cx<Caps>, path: impl Into<String>) -> Result<Self, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        Self::open_with_env(cx, path, ConnectionEnv::default()).await
    }

    /// Blocking counterpart of [`AsyncConnection::open`].
    ///
    /// # Errors
    ///
    /// Same as [`AsyncConnection::open_sync_with_env`].
    pub fn open_sync(path: impl Into<String>) -> Result<Self, FrankenError> {
        Self::open_sync_with_env(path, ConnectionEnv::default())
    }

    /// Opens `path` with `env`, blocking the calling thread until the worker reports the outcome.
    ///
    /// The worker runs on the current runtime's blocking pool when one exists. Otherwise a private
    /// single-threaded runtime with one blocking thread is built, and the returned connection
    /// owns it.
    ///
    /// # Errors
    ///
    /// * The error from opening the connection.
    /// * An `Internal` error if the private runtime cannot be built, the worker cannot be
    ///   spawned, or the worker exits without reporting.
    pub fn open_sync_with_env(
        path: impl Into<String>,
        env: ConnectionEnv,
    ) -> Result<Self, FrankenError> {
        let path = path.into();
        let (open_tx, open_rx) = mpsc::sync_channel::<Result<(), FrankenError>>(1);
        let (cmd_tx, cmd_rx) = mpsc::sync_channel::<Command>(32);
        let worker_cx = NativeCx::for_request();
        let (owned_runtime, runtime_handle) = current_or_owned_runtime()?;
        let worker = spawn_worker_task(
            &runtime_handle,
            worker_cx.clone(),
            path,
            env,
            cmd_rx,
            open_tx,
        )?;
        match wait_for_worker_open(open_rx) {
            Ok(()) => Ok(Self::from_worker(cmd_tx, worker, worker_cx, owned_runtime)),
            Err(error) => {
                // An error here means the worker's job has finished (or is about to), so the
                // join is short.
                join_worker_task(worker);
                Err(error)
            }
        }
    }

    /// Opens `path` with `env`, awaiting the worker's report without blocking the executor.
    ///
    /// Requires a current asupersync runtime. The worker is spawned on its blocking pool, and the
    /// open report is awaited through it.
    ///
    /// # Errors
    ///
    /// * `FrankenError::Interrupt` if `cx` is already cancelled or the wait reports cancellation.
    /// * An `Internal` error if there is no suitable runtime or the worker dies.
    /// * The error from opening the connection.
    pub async fn open_with_env<Caps>(
        cx: &Cx<Caps>,
        path: impl Into<String>,
        env: ConnectionEnv,
    ) -> Result<Self, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        checkpoint_or_interrupt(cx)?;
        let path = path.into();
        let (open_tx, open_rx) = mpsc::sync_channel::<Result<(), FrankenError>>(1);
        let (cmd_tx, cmd_rx) = mpsc::sync_channel::<Command>(32);
        let runtime = Runtime::current_handle().ok_or_else(requires_runtime_err)?;
        let worker_cx = NativeCx::for_request();
        let worker = spawn_worker_task(&runtime, worker_cx.clone(), path, env, cmd_rx, open_tx)?;
        if let Err(error) = recv_sync_response(cx, open_rx).await? {
            join_worker_task(worker);
            return Err(error);
        }
        Ok(Self::from_worker(cmd_tx, worker, worker_cx, None))
    }

    /// Wraps an already-opened worker; no transaction has been started through it yet.
    fn from_worker(
        cmd_tx: mpsc::SyncSender<Command>,
        worker: BlockingTaskHandle,
        worker_cx: NativeCx,
        owned_runtime: Option<Runtime>,
    ) -> Self {
        Self {
            cmd_tx: Some(cmd_tx),
            worker: Some(worker),
            worker_cx: Some(worker_cx),
            owned_runtime,
            in_txn: Arc::new(AtomicBool::new(false)),
        }
    }

    /// Returns the command queue, or an `Internal` error once the connection has been closed.
    fn sender(&self) -> Result<&mpsc::SyncSender<Command>, FrankenError> {
        self.cmd_tx
            .as_ref()
            .ok_or_else(|| FrankenError::Internal("AsyncConnection has been closed".to_owned()))
    }

    /// Sends one command to the worker and awaits its reply.
    ///
    /// `cx` is checkpointed before anything is sent. `make_command` receives the reply channel and
    /// must embed it in the command it builds.
    ///
    /// # Errors
    ///
    /// * `Interrupt` if `cx` is cancelled beforehand or the wait reports cancellation.
    /// * An `Internal` error if the connection is closed or the worker is gone.
    /// * Otherwise, whatever the worker's `Connection` call returned.
    async fn request<Caps, T>(
        &self,
        cx: &Cx<Caps>,
        make_command: impl FnOnce(Responder<T>) -> Command,
    ) -> Result<T, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
        T: Send + 'static,
    {
        checkpoint_or_interrupt(cx)?;
        let (tx, rx) = mpsc::sync_channel(1);
        self.sender()?
            .send(make_command(tx))
            .map_err(send_err)?;
        recv_sync_response(cx, rx).await?
    }

    /// Runs `sql` through `Connection::query` on the worker and returns the rows.
    ///
    /// # Errors
    ///
    /// See the private `request` helper: interruption, a closed or dead worker, or the query's
    /// own error.
    pub async fn query<Caps>(&self, cx: &Cx<Caps>, sql: &str) -> Result<Vec<Row>, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        self.request(cx, |tx| Command::Query {
            sql: sql.to_owned(),
            tx,
        })
        .await
    }

    /// Runs `sql` with bound `params` through `Connection::query_with_params`.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the query's own error.
    pub async fn query_with_params<Caps>(
        &self,
        cx: &Cx<Caps>,
        sql: &str,
        params: &[SqliteValue],
    ) -> Result<Vec<Row>, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        self.request(cx, |tx| Command::QueryWithParams {
            sql: sql.to_owned(),
            params: params.to_vec(),
            tx,
        })
        .await
    }

    /// Runs `sql` through `Connection::query_row` and returns that single row.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the query's own error.
    pub async fn query_row<Caps>(&self, cx: &Cx<Caps>, sql: &str) -> Result<Row, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        self.request(cx, |tx| Command::QueryRow {
            sql: sql.to_owned(),
            tx,
        })
        .await
    }

    /// Runs `sql` with bound `params` through `Connection::query_row_with_params`.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the query's own error.
    pub async fn query_row_with_params<Caps>(
        &self,
        cx: &Cx<Caps>,
        sql: &str,
        params: &[SqliteValue],
    ) -> Result<Row, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        self.request(cx, |tx| Command::QueryRowWithParams {
            sql: sql.to_owned(),
            params: params.to_vec(),
            tx,
        })
        .await
    }

    /// Runs `sql` through `Connection::execute` and returns the count it reports.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the statement's own error.
    pub async fn execute<Caps>(&self, cx: &Cx<Caps>, sql: &str) -> Result<usize, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        self.request(cx, |tx| Command::Execute {
            sql: sql.to_owned(),
            tx,
        })
        .await
    }

    /// Runs `sql` with bound `params` through `Connection::execute_with_params`.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the statement's own error.
    pub async fn execute_with_params<Caps>(
        &self,
        cx: &Cx<Caps>,
        sql: &str,
        params: &[SqliteValue],
    ) -> Result<usize, FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        self.request(cx, |tx| Command::ExecuteWithParams {
            sql: sql.to_owned(),
            params: params.to_vec(),
            tx,
        })
        .await
    }

    /// Runs `sql` through `Connection::execute_batch`.
    ///
    /// Transactions opened or closed by SQL text here are not reflected by
    /// [`AsyncConnection::in_transaction`].
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the batch's own error.
    pub async fn execute_batch<Caps>(&self, cx: &Cx<Caps>, sql: &str) -> Result<(), FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        self.request(cx, |tx| Command::ExecuteBatch {
            sql: sql.to_owned(),
            tx,
        })
        .await
    }

    /// Calls `Connection::begin_transaction`; on success [`AsyncConnection::in_transaction`]
    /// becomes `true`.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the begin's own error. The flag is left
    /// unchanged in every error case.
    pub async fn begin_transaction<Caps>(&self, cx: &Cx<Caps>) -> Result<(), FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        let result = self
            .request(cx, |tx| Command::BeginTransaction { tx })
            .await;
        if result.is_ok() {
            self.in_txn.store(true, Ordering::Release);
        }
        result
    }

    /// Calls `Connection::commit_transaction`; on success the transaction flag is cleared.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the commit's own error. The flag is left
    /// unchanged in every error case.
    pub async fn commit_transaction<Caps>(&self, cx: &Cx<Caps>) -> Result<(), FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        let result = self
            .request(cx, |tx| Command::CommitTransaction { tx })
            .await;
        if result.is_ok() {
            self.in_txn.store(false, Ordering::Release);
        }
        result
    }

    /// Calls `Connection::rollback_transaction`; on success the transaction flag is cleared.
    ///
    /// # Errors
    ///
    /// Interruption, a closed or dead worker, or the rollback's own error. The flag is left
    /// unchanged in every error case.
    pub async fn rollback_transaction<Caps>(&self, cx: &Cx<Caps>) -> Result<(), FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        let result = self
            .request(cx, |tx| Command::RollbackTransaction { tx })
            .await;
        if result.is_ok() {
            self.in_txn.store(false, Ordering::Release);
        }
        result
    }

    /// Reports whether a transaction started via [`AsyncConnection::begin_transaction`] is still
    /// open, as tracked by this wrapper.
    ///
    /// `BEGIN`/`COMMIT` issued as SQL text through `execute`/`execute_batch` is not tracked.
    #[must_use]
    pub fn in_transaction(&self) -> bool {
        self.in_txn.load(Ordering::Acquire)
    }

    /// Closes the underlying connection via `Connection::close_in_place` and stops the worker.
    ///
    /// Calling this on an already-closed connection is a no-op returning `Ok(())`. Once the close
    /// request has been queued, [`AsyncConnection::in_transaction`] reports `false`: the worker
    /// exits after handling it whatever the outcome, so no transaction can survive.
    ///
    /// # Errors
    ///
    /// * `Interrupt` if `cx` is cancelled beforehand or the wait reports cancellation.
    /// * An `Internal` error if the worker is gone.
    /// * The close's own error.
    ///
    /// On an early error the worker is cleaned up by `Drop`.
    pub async fn close<Caps>(&mut self, cx: &Cx<Caps>) -> Result<(), FrankenError>
    where
        Caps: fsqlite_types::cx::cap::SubsetOf<fsqlite_types::cx::cap::All>,
    {
        checkpoint_or_interrupt(cx)?;
        let Some(cmd_tx) = self.cmd_tx.take() else {
            return Ok(());
        };
        let (tx, rx) = mpsc::sync_channel(1);
        cmd_tx.send(Command::Close { tx }).map_err(send_err)?;
        self.in_txn.store(false, Ordering::Release);
        let result = recv_sync_response(cx, rx).await?;
        if let Some(worker_cx) = self.worker_cx.take() {
            worker_cx.cancel();
        }
        if let Some(handle) = self.worker.take() {
            join_worker_task(handle);
        }
        self.owned_runtime = None;
        result
    }
}
impl Drop for AsyncConnection {
    /// Stops the worker (if still running) and waits for it to exit.
    ///
    /// The `Shutdown` request is sent with `try_send`, so a full command queue cannot stall the
    /// dropping thread. The worker context is cancelled right afterwards, and the worker checks it
    /// before taking each command, so the worker stops whether or not `Shutdown` was queued.
    fn drop(&mut self) {
        if let Some(cmd_tx) = self.cmd_tx.take() {
            let _ = cmd_tx.try_send(Command::Shutdown);
        }
        if let Some(worker_cx) = self.worker_cx.take() {
            worker_cx.cancel();
        }
        if let Some(handle) = self.worker.take() {
            join_worker_task(handle);
        }
        self.owned_runtime = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use asupersync::runtime::RuntimeBuilder;
    use fsqlite_types::cx::Cx;

    /// Current-thread runtime with two blocking threads, so the long-lived worker task and a
    /// reply waiter can run at the same time.
    fn test_runtime() -> Runtime {
        RuntimeBuilder::current_thread()
            .blocking_threads(2, 2)
            .build()
            .expect("test runtime should build")
    }

    #[test]
    fn test_async_connection_basic() {
        test_runtime().block_on(async {
            let cx = Cx::new();
            let conn = AsyncConnection::open(&cx, ":memory:")
                .await
                .expect("open should succeed");
            conn.execute(&cx, "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
                .await
                .expect("create table should succeed");
            conn.execute_with_params(
                &cx,
                "INSERT INTO t VALUES (?1, ?2)",
                &[SqliteValue::Integer(1), SqliteValue::Text("hello".into())],
            )
            .await
            .expect("insert should succeed");
            let rows = conn
                .query(&cx, "SELECT * FROM t")
                .await
                .expect("query should succeed");
            assert_eq!(rows.len(), 1);
            assert_eq!(rows[0].get(0), Some(&SqliteValue::Integer(1)));
            assert_eq!(rows[0].get(1), Some(&SqliteValue::Text("hello".into())));
            let row = conn
                .query_row(&cx, "SELECT name FROM t WHERE id = 1")
                .await
                .expect("query_row should succeed");
            assert_eq!(row.get(0), Some(&SqliteValue::Text("hello".into())));
            let count = conn
                .execute(&cx, "DELETE FROM t")
                .await
                .expect("delete should succeed");
            assert_eq!(count, 1);
        });
    }

    #[test]
    fn test_async_connection_params_variants() {
        test_runtime().block_on(async {
            let cx = Cx::new();
            let conn = AsyncConnection::open(&cx, ":memory:")
                .await
                .expect("open should succeed");
            conn.execute(&cx, "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
                .await
                .expect("create table should succeed");
            for (id, name) in [(1, "a"), (2, "b")] {
                conn.execute_with_params(
                    &cx,
                    "INSERT INTO t VALUES (?1, ?2)",
                    &[SqliteValue::Integer(id), SqliteValue::Text(name.into())],
                )
                .await
                .expect("insert should succeed");
            }
            let rows = conn
                .query_with_params(&cx, "SELECT name FROM t WHERE id = ?1", &[
                    SqliteValue::Integer(2),
                ])
                .await
                .expect("query_with_params should succeed");
            assert_eq!(rows.len(), 1);
            assert_eq!(rows[0].get(0), Some(&SqliteValue::Text("b".into())));
            let row = conn
                .query_row_with_params(&cx, "SELECT id FROM t WHERE name = ?1", &[
                    SqliteValue::Text("a".into()),
                ])
                .await
                .expect("query_row_with_params should succeed");
            assert_eq!(row.get(0), Some(&SqliteValue::Integer(1)));
        });
    }

    #[test]
    fn test_async_connection_transaction() {
        test_runtime().block_on(async {
            let cx = Cx::new();
            let conn = AsyncConnection::open(&cx, ":memory:")
                .await
                .expect("open should succeed");
            assert!(!conn.in_transaction(), "fresh connection has no transaction");
            conn.execute(&cx, "CREATE TABLE t (id INTEGER PRIMARY KEY)")
                .await
                .expect("create should succeed");
            conn.begin_transaction(&cx).await.expect("begin");
            assert!(conn.in_transaction(), "begin should set the flag");
            conn.execute(&cx, "INSERT INTO t VALUES (1)")
                .await
                .expect("insert");
            conn.rollback_transaction(&cx).await.expect("rollback");
            assert!(!conn.in_transaction(), "rollback should clear the flag");
            let rows = conn.query(&cx, "SELECT * FROM t").await.expect("query");
            assert!(rows.is_empty(), "rollback should have removed the row");
            conn.begin_transaction(&cx).await.expect("begin");
            assert!(conn.in_transaction(), "begin should set the flag");
            conn.execute(&cx, "INSERT INTO t VALUES (2)")
                .await
                .expect("insert");
            conn.commit_transaction(&cx).await.expect("commit");
            assert!(!conn.in_transaction(), "commit should clear the flag");
            let rows = conn.query(&cx, "SELECT * FROM t").await.expect("query");
            assert_eq!(rows.len(), 1);
        });
    }

    #[test]
    fn test_async_connection_cancel() {
        test_runtime().block_on(async {
            let cx = Cx::new();
            let conn = AsyncConnection::open(&cx, ":memory:")
                .await
                .expect("open should succeed");
            cx.cancel();
            let result = conn.execute(&cx, "SELECT 1").await;
            assert!(result.is_err(), "operation should fail after cancellation");
            match result.unwrap_err() {
                FrankenError::Interrupt => {}
                other => panic!("expected Interrupt, got: {other}"),
            }
        });
    }

    #[test]
    fn test_async_connection_execute_batch() {
        test_runtime().block_on(async {
            let cx = Cx::new();
            let conn = AsyncConnection::open(&cx, ":memory:")
                .await
                .expect("open should succeed");
            conn.execute_batch(&cx, "CREATE TABLE a (x INTEGER); CREATE TABLE b (y TEXT);")
                .await
                .expect("batch should succeed");
            let _ = conn.query(&cx, "SELECT * FROM a").await.expect("table a");
            let _ = conn.query(&cx, "SELECT * FROM b").await.expect("table b");
        });
    }

    #[test]
    fn test_async_connection_close() {
        test_runtime().block_on(async {
            let cx = Cx::new();
            let mut conn = AsyncConnection::open(&cx, ":memory:")
                .await
                .expect("open should succeed");
            conn.close(&cx).await.expect("close should succeed");
            let result = conn.query(&cx, "SELECT 1").await;
            assert!(result.is_err(), "query after close should fail");
            conn.close(&cx)
                .await
                .expect("closing an already-closed connection is a no-op");
        });
    }
}