use std::{
num::NonZeroUsize,
path::{Path, PathBuf},
sync::{
atomic::{AtomicU64, Ordering::Relaxed},
Arc,
},
thread::available_parallelism,
};
use crate::{Client, ClientBuilder, Error, JournalMode};
use futures_util::future::join_all;
use rusqlite::{Connection, OpenFlags};
/// Builder for a [`Pool`] of SQLite connections.
///
/// Every setting is copied into each [`ClientBuilder`] the pool opens. Create one with
/// [`PoolBuilder::new`], chain the setters, then call [`PoolBuilder::open`] or
/// [`PoolBuilder::open_blocking`].
#[derive(Clone, Debug, Default)]
pub struct PoolBuilder {
    /// Database path forwarded to each client; `None` until `path` is called.
    path: Option<PathBuf>,
    /// Open flags forwarded to each client.
    flags: OpenFlags,
    /// Journal mode forwarded to each client; `None` until `journal_mode` is called.
    journal_mode: Option<JournalMode>,
    /// VFS name forwarded to each client; `None` until `vfs` is called.
    vfs: Option<String>,
    /// Number of connections to open; `None` means "use `available_parallelism()`"
    /// when the pool is opened.
    num_conns: Option<usize>,
}
impl PoolBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.path = Some(path.as_ref().into());
self
}
pub fn flags(mut self, flags: OpenFlags) -> Self {
self.flags = flags;
self
}
pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
self.journal_mode = Some(journal_mode);
self
}
pub fn vfs(mut self, vfs: &str) -> Self {
self.vfs = Some(vfs.to_owned());
self
}
pub fn num_conns(mut self, num_conns: usize) -> Self {
self.num_conns = Some(num_conns.max(1));
self
}
pub async fn open(self) -> Result<Pool, Error> {
let num_conns = self.get_num_conns();
let opens = (0..num_conns).map(|_| {
ClientBuilder {
path: self.path.clone(),
flags: self.flags,
journal_mode: self.journal_mode,
vfs: self.vfs.clone(),
}
.open()
});
let clients = join_all(opens)
.await
.into_iter()
.collect::<Result<Vec<Client>, Error>>()?;
Ok(Pool {
state: Arc::new(State {
clients,
counter: AtomicU64::new(0),
}),
})
}
pub fn open_blocking(self) -> Result<Pool, Error> {
let num_conns = self.get_num_conns();
let clients = (0..num_conns)
.map(|_| {
ClientBuilder {
path: self.path.clone(),
flags: self.flags,
journal_mode: self.journal_mode,
vfs: self.vfs.clone(),
}
.open_blocking()
})
.collect::<Result<Vec<Client>, Error>>()?;
Ok(Pool {
state: Arc::new(State {
clients,
counter: AtomicU64::new(0),
}),
})
}
fn get_num_conns(&self) -> usize {
self.num_conns.unwrap_or_else(|| {
available_parallelism()
.unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
.into()
})
}
}
/// A fixed set of [`Client`]s that hands out work round-robin.
///
/// Build one with [`PoolBuilder`]. Cloning a `Pool` only clones the inner [`Arc`], so all
/// clones share the same clients and the same round-robin counter.
#[derive(Clone)]
pub struct Pool {
    state: Arc<State>,
}
/// Shared state behind every clone of a [`Pool`].
struct State {
    /// The pool's clients. Pools built by `PoolBuilder` always hold at least one, because
    /// the connection count is clamped to a minimum of 1.
    clients: Vec<Client>,
    /// Incremented on every request; its value modulo `clients.len()` selects the client.
    counter: AtomicU64,
}
impl Pool {
pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn(func).await
}
pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_mut(func).await
}
pub async fn close(&self) -> Result<(), Error> {
let closes = self.state.clients.iter().map(|client| client.close());
let res = join_all(closes).await;
res.into_iter().collect::<Result<Vec<_>, Error>>()?;
Ok(())
}
pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_blocking(func)
}
pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_mut_blocking(func)
}
pub fn close_blocking(&self) -> Result<(), Error> {
self.state
.clients
.iter()
.try_for_each(|client| client.close_blocking())
}
fn get(&self) -> &Client {
let n = self.state.counter.fetch_add(1, Relaxed);
&self.state.clients[n as usize % self.state.clients.len()]
}
pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
where
F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
T: Send + 'static,
{
let func = Arc::new(func);
let futures = self.state.clients.iter().map(|client| {
let func = func.clone();
async move { client.conn(move |conn| func(conn)).await }
});
join_all(futures).await
}
pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
where
F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
T: Send + 'static,
{
let func = Arc::new(func);
self.state
.clients
.iter()
.map(|client| {
let func = func.clone();
client.conn_blocking(move |conn| func(conn))
})
.collect()
}
}