use std::{
path::{Path, PathBuf},
thread,
};
use crossbeam_channel::{Sender, bounded, unbounded};
use rusqlite::{Connection, OpenFlags};
use tokio::sync::oneshot;
use crate::errors::Result;
/// Configuration used to open a [`SqliteClient`].
///
/// A default builder has no path, which `SqliteClient::open` turns into an in-memory
/// database (`":memory:"`), and uses `OpenFlags::default()`.
#[derive(Clone, Debug, Default)]
pub struct SqliteClientBuilder {
    /// Database file to open; `None` means an in-memory database.
    pub(crate) path: Option<PathBuf>,
    /// Flags forwarded to `Connection::open_with_flags`.
    pub(crate) flags: OpenFlags,
}

impl SqliteClientBuilder {
    /// Creates a builder with default settings (identical to `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the database file to open, replacing any previously configured path.
    pub fn path<P: AsRef<Path>>(self, path: P) -> Self {
        Self {
            path: Some(path.as_ref().to_path_buf()),
            ..self
        }
    }

    /// Sets the SQLite open flags, replacing any previously configured flags.
    pub fn flags(self, flags: OpenFlags) -> Self {
        Self { flags, ..self }
    }

    /// Consumes the builder and opens the connection on a dedicated thread.
    pub async fn open(self) -> Result<SqliteClient, Error> {
        SqliteClient::open(self).await
    }
}
/// Asynchronous handle to a SQLite connection that lives on its own OS thread.
///
/// Requests are sent to that thread as boxed closures over an unbounded channel and are run
/// in the order they were sent. Dropping the handle sends a shutdown request and blocks the
/// dropping thread until the connection thread replies (see `close_blocking`).
pub struct SqliteClient {
    /// Sending half of the command queue drained by the connection thread.
    conn_tx: Sender<Command>,
}

impl SqliteClient {
    /// Spawns the connection thread, opens the database there and returns a handle to it.
    ///
    /// The path comes from `builder.path`, falling back to `":memory:"` when unset, and
    /// `builder.flags` is passed to `Connection::open_with_flags`. Once the connection is
    /// open, the thread serves commands until a shutdown succeeds or every sender is gone.
    ///
    /// # Errors
    ///
    /// * `Error::Rusqlite` if the connection could not be opened.
    /// * `Error::Closed` if the thread ended without reporting a result.
    async fn open(mut builder: SqliteClientBuilder) -> Result<Self, Error> {
        let path = builder.path.take().unwrap_or_else(|| ":memory:".into());
        let (open_tx, open_rx) = oneshot::channel();
        thread::spawn(move || {
            let (conn_tx, conn_rx) = unbounded();
            let mut conn = match Connection::open_with_flags(path, builder.flags) {
                Ok(conn) => conn,
                Err(err) => {
                    if let Err(Err(err)) = open_tx.send(Err(err)) {
                        tracing::error!("Error sending sqlite connection error: {err:?}");
                    }
                    return;
                }
            };
            let client = Self { conn_tx };
            if let Err(unsent) = open_tx.send(Ok(client)) {
                tracing::error!("Error sending sqlite connection");
                // Nobody is waiting for the client anymore (e.g. the `open` future was
                // dropped). Dropping it runs `SqliteClient::drop`, which queues a shutdown
                // and waits for a reply. This thread would have to give that reply, but the
                // command loop below never runs, so the wait would block forever and leak
                // both the thread and the connection. Dropping the receiver first makes
                // that send fail immediately, so the drop returns without waiting.
                drop(conn_rx);
                drop(unsent);
                // Returning drops `conn`, which closes the connection.
                return;
            }
            // Serve commands until a shutdown succeeds or every sender has been dropped.
            while let Ok(cmd) = conn_rx.recv() {
                match cmd {
                    Command::Func(func) => func(&mut conn),
                    Command::Shutdown(func) => match conn.close() {
                        Ok(()) => {
                            func(Ok(()));
                            return;
                        }
                        Err((c, e)) => {
                            // `close` hands the connection back on failure; keep serving.
                            conn = c;
                            func(Err(e.into()));
                        }
                    },
                }
            }
        });
        Ok(open_rx.await??)
    }
}
impl SqliteClient {
pub async fn conn<F, T>(&self, func: F) -> Result<T>
where
F: FnOnce(&Connection) -> Result<T> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.conn_tx
.send(Command::Func(Box::new(move |conn| {
if tx.send(func(conn)).is_err() {
tracing::error!("Error sending sqlite response");
}
})))
.map_err(Error::from)?;
rx.await.map_err(Error::from)?
}
pub async fn conn_mut<F, T>(&self, func: F) -> Result<T>
where
F: FnOnce(&mut Connection) -> Result<T> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.conn_tx
.send(Command::Func(Box::new(move |conn| {
if tx.send(func(conn)).is_err() {
tracing::error!("Error sending sqlite response");
}
})))
.map_err(Error::from)?;
rx.await.map_err(Error::from)?
}
pub fn close_blocking(&self) -> Result<(), Error> {
let (tx, rx) = bounded(1);
let func = Box::new(move |res| _ = tx.send(res));
if self.conn_tx.send(Command::Shutdown(func)).is_err() {
return Ok(());
}
rx.recv().unwrap_or(Ok(()))
}
}
impl Drop for SqliteClient {
fn drop(&mut self) {
if let Err(err) = self.close_blocking() {
tracing::error!("Error closing sqlite client: {err:?}");
}
}
}
/// Requests processed, one at a time and in send order, by the connection thread.
enum Command {
    /// Run the closure with exclusive access to the connection.
    Func(Box<dyn FnOnce(&mut Connection) + Send>),
    /// Close the connection and pass the outcome to the callback. On success the connection
    /// thread exits; on failure it keeps the connection and continues serving commands.
    Shutdown(Box<dyn FnOnce(Result<(), Error>) + Send>),
}
/// Errors produced by the SQLite client.
#[derive(Debug)]
pub enum Error {
    /// The connection thread is gone, so the request could not be sent or answered.
    Closed,
    /// An error reported by rusqlite.
    Rusqlite(rusqlite::Error),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("connection to sqlite database closed"),
            // Forward to the wrapped error so its text (and any formatter flags) pass through.
            Self::Rusqlite(inner) => std::fmt::Display::fmt(inner, f),
        }
    }
}

impl std::error::Error for Error {
    /// NOTE(review): `Rusqlite` both forwards `Display` and returns the inner error here, so
    /// reporters that walk the source chain will print the rusqlite message twice.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Closed => None,
            Self::Rusqlite(inner) => Some(inner),
        }
    }
}
impl<T> From<crossbeam_channel::SendError<T>> for Error {
fn from(_value: crossbeam_channel::SendError<T>) -> Self {
Error::Closed
}
}
impl From<crossbeam_channel::RecvError> for Error {
fn from(_value: crossbeam_channel::RecvError) -> Self {
Error::Closed
}
}
impl From<oneshot::error::RecvError> for Error {
fn from(_value: oneshot::error::RecvError) -> Self {
Error::Closed
}
}
impl From<rusqlite::Error> for Error {
fn from(value: rusqlite::Error) -> Self {
Error::Rusqlite(value)
}
}