use std::{
path::{Path, PathBuf},
thread,
};
use crate::Error;
use crossbeam_channel::{bounded, unbounded, Sender};
use futures_channel::oneshot;
use rusqlite::{Connection, OpenFlags};
/// Configuration for opening a [`Client`].
///
/// Every setter consumes the builder and returns the updated one, so ignoring a
/// setter's return value silently drops that setting; `#[must_use]` on the type
/// turns that mistake into a compiler warning.
#[must_use = "builder methods return the updated builder; finish with `open` or `open_blocking`"]
#[derive(Clone, Debug, Default)]
pub struct ClientBuilder {
    /// Database file to open; `None` opens `":memory:"`.
    pub(crate) path: Option<PathBuf>,
    /// Flags used when opening the connection (`OpenFlags::default()` unless overridden).
    pub(crate) flags: OpenFlags,
    /// Journal mode applied with `PRAGMA journal_mode` right after opening, if set.
    pub(crate) journal_mode: Option<JournalMode>,
    /// Name of the SQLite VFS to open with; `None` opens without naming a VFS.
    pub(crate) vfs: Option<String>,
}

impl ClientBuilder {
    /// Creates a builder with every option at its default: an in-memory
    /// database, `OpenFlags::default()`, no journal-mode pragma and no named VFS.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the path of the database file to open.
    ///
    /// When no path is set, `":memory:"` is opened instead.
    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.path = Some(path.as_ref().into());
        self
    }

    /// Replaces the flags passed to SQLite when the connection is opened.
    pub fn flags(mut self, flags: OpenFlags) -> Self {
        self.flags = flags;
        self
    }

    /// Requests that `journal_mode` be applied via `PRAGMA journal_mode` right
    /// after the connection is opened.
    ///
    /// Opening fails with `Error::PragmaUpdate` if the mode SQLite reports back
    /// does not match the requested one (compared ASCII case-insensitively).
    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
        self.journal_mode = Some(journal_mode);
        self
    }

    /// Opens the connection through the SQLite VFS named `vfs`.
    pub fn vfs(mut self, vfs: &str) -> Self {
        self.vfs = Some(vfs.to_owned());
        self
    }

    /// Opens the connection on a newly spawned worker thread and awaits the outcome.
    ///
    /// # Errors
    /// Returns the error raised while opening or configuring the connection.
    pub async fn open(self) -> Result<Client, Error> {
        Client::open_async(self).await
    }

    /// Like [`ClientBuilder::open`], but blocks the calling thread until the
    /// connection has been opened (or has failed to open).
    ///
    /// # Errors
    /// Returns the error raised while opening or configuring the connection.
    pub fn open_blocking(self) -> Result<Client, Error> {
        Client::open_blocking(self)
    }
}
/// Requests sent from a [`Client`] to the worker thread that owns the connection.
enum Command {
    /// Run the boxed closure against the worker's connection.
    Func(Box<dyn FnOnce(&mut Connection) + Send>),
    /// Close the connection and pass the outcome of `Connection::close` to the callback.
    /// On success the worker thread exits; on failure it keeps the returned
    /// connection and goes on serving commands.
    Shutdown(Box<dyn FnOnce(Result<(), Error>) + Send>),
}
/// Handle to a SQLite connection owned by a dedicated worker thread.
///
/// Clones send to the same command channel and therefore share one connection.
/// The worker stops once every clone has been dropped, or after a successful close.
#[derive(Clone)]
pub struct Client {
    /// Sending half of the worker thread's command channel.
    conn_tx: Sender<Command>,
}
impl Client {
async fn open_async(builder: ClientBuilder) -> Result<Self, Error> {
let (open_tx, open_rx) = oneshot::channel();
Self::open(builder, |res| {
_ = open_tx.send(res);
});
open_rx.await?
}
fn open_blocking(builder: ClientBuilder) -> Result<Self, Error> {
let (conn_tx, conn_rx) = bounded(1);
Self::open(builder, move |res| {
_ = conn_tx.send(res);
});
conn_rx.recv()?
}
fn open<F>(builder: ClientBuilder, func: F)
where
F: FnOnce(Result<Self, Error>) + Send + 'static,
{
thread::spawn(move || {
let (conn_tx, conn_rx) = unbounded();
let mut conn = match Client::create_conn(builder) {
Ok(conn) => conn,
Err(err) => {
func(Err(err));
return;
}
};
let client = Self { conn_tx };
func(Ok(client));
while let Ok(cmd) = conn_rx.recv() {
match cmd {
Command::Func(func) => func(&mut conn),
Command::Shutdown(func) => match conn.close() {
Ok(()) => {
func(Ok(()));
return;
}
Err((c, e)) => {
conn = c;
func(Err(e.into()));
}
},
}
}
});
}
fn create_conn(mut builder: ClientBuilder) -> Result<Connection, Error> {
let path = builder.path.take().unwrap_or_else(|| ":memory:".into());
let conn = if let Some(vfs) = builder.vfs.take() {
Connection::open_with_flags_and_vfs(path, builder.flags, vfs.as_str())?
} else {
Connection::open_with_flags(path, builder.flags)?
};
if let Some(journal_mode) = builder.journal_mode.take() {
let val = journal_mode.as_str();
let out: String =
conn.pragma_update_and_check(None, "journal_mode", val, |row| row.get(0))?;
if !out.eq_ignore_ascii_case(val) {
return Err(Error::PragmaUpdate {
name: "journal_mode",
exp: val,
got: out,
});
}
}
Ok(conn)
}
pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(func(conn));
})))?;
Ok(rx.await??)
}
pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(func(conn));
})))?;
Ok(rx.await??)
}
pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
where
F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
T: Send + 'static,
E: From<rusqlite::Error> + From<Error> + Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.conn_tx
.send(Command::Func(Box::new(move |conn| {
_ = tx.send(func(conn));
})))
.map_err(Error::from)?;
rx.await.map_err(Error::from)?
}
pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
where
F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
T: Send + 'static,
E: From<rusqlite::Error> + From<Error> + Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.conn_tx
.send(Command::Func(Box::new(move |conn| {
_ = tx.send(func(conn));
})))
.map_err(Error::from)?;
rx.await.map_err(Error::from)?
}
pub async fn close(&self) -> Result<(), Error> {
let (tx, rx) = oneshot::channel();
let func = Box::new(|res| _ = tx.send(res));
if self.conn_tx.send(Command::Shutdown(func)).is_err() {
return Ok(());
}
rx.await.unwrap_or(Ok(()))
}
pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = bounded(1);
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(func(conn));
})))?;
Ok(rx.recv()??)
}
pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = bounded(1);
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(func(conn));
})))?;
Ok(rx.recv()??)
}
pub fn close_blocking(&self) -> Result<(), Error> {
let (tx, rx) = bounded(1);
let func = Box::new(move |res| _ = tx.send(res));
if self.conn_tx.send(Command::Shutdown(func)).is_err() {
return Ok(());
}
rx.recv().unwrap_or(Ok(()))
}
}
/// Journal modes that can be requested through [`ClientBuilder::journal_mode`].
///
/// The chosen mode is applied with `PRAGMA journal_mode` right after the
/// connection is opened. Besides `Clone`/`Copy`/`Debug`, the type derives
/// `PartialEq`, `Eq` and `Hash` so callers can compare modes and use them as keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum JournalMode {
    /// `PRAGMA journal_mode = DELETE`
    Delete,
    /// `PRAGMA journal_mode = TRUNCATE`
    Truncate,
    /// `PRAGMA journal_mode = PERSIST`
    Persist,
    /// `PRAGMA journal_mode = MEMORY`
    Memory,
    /// `PRAGMA journal_mode = WAL`
    Wal,
    /// `PRAGMA journal_mode = OFF`
    Off,
}

impl JournalMode {
    /// Returns the upper-case keyword for this mode, as used in `PRAGMA journal_mode`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            JournalMode::Delete => "DELETE",
            JournalMode::Truncate => "TRUNCATE",
            JournalMode::Persist => "PERSIST",
            JournalMode::Memory => "MEMORY",
            JournalMode::Wal => "WAL",
            JournalMode::Off => "OFF",
        }
    }
}