use std::{
num::NonZeroUsize,
path::{Path, PathBuf},
sync::{
atomic::{AtomicU64, Ordering::Relaxed},
Arc,
},
thread::available_parallelism,
};
use crate::{Client, ClientBuilder, Error, JournalMode};
use futures_util::future::join_all;
use rusqlite::{Connection, OpenFlags};
/// Builder for a [`Pool`] of SQLite connections.
///
/// Every option starts unset (`None`), with `flags` at `OpenFlags::default()`;
/// options are applied when the pool is opened.
#[derive(Clone, Debug, Default)]
pub struct PoolBuilder {
    /// Database file location; cleared by `shared_memory`, so at most one of
    /// `path` / `shared_memory_name` is set.
    path: Option<PathBuf>,
    /// Name of a shared in-memory database; cleared by `path`.
    shared_memory_name: Option<String>,
    /// Open flags for every connection (adjusted at open time for shared-memory pools).
    flags: OpenFlags,
    /// Journal mode forwarded to each connection's `ClientBuilder`.
    journal_mode: Option<JournalMode>,
    /// VFS name forwarded to each connection's `ClientBuilder`.
    vfs: Option<String>,
    /// Requested connection count (at least 1 when set); `None` picks a default at open time.
    num_conns: Option<usize>,
    /// Queue capacity forwarded to each connection's `ClientBuilder`.
    queue_capacity: Option<usize>,
}
impl PoolBuilder {
    /// Creates a builder with every option unset; identical to [`PoolBuilder::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Backs the pool with the database at `path`.
    ///
    /// Replaces any earlier [`shared_memory`](Self::shared_memory) setting. Leaving both
    /// unset, or passing `":memory:"`, yields an anonymous in-memory database, which is
    /// limited to a single connection.
    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.path = Some(path.as_ref().into());
        self.shared_memory_name = None;
        self
    }

    /// Backs the pool with a named in-memory database shared by all of its connections.
    ///
    /// Each connection opens `file:<name>?mode=memory&cache=shared` (with `name`
    /// percent-encoded), and the URI and shared-cache open flags are forced on.
    /// Replaces any earlier [`path`](Self::path) setting. The name must not be empty.
    pub fn shared_memory<N: AsRef<str>>(mut self, name: N) -> Self {
        self.path = None;
        self.shared_memory_name = Some(name.as_ref().to_owned());
        self
    }

    /// Sets the SQLite open flags used for every connection.
    pub fn flags(mut self, flags: OpenFlags) -> Self {
        self.flags = flags;
        self
    }

    /// Sets the journal mode passed to each connection's [`ClientBuilder`].
    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
        self.journal_mode = Some(journal_mode);
        self
    }

    /// Sets the SQLite VFS name passed to each connection's [`ClientBuilder`].
    pub fn vfs(mut self, vfs: &str) -> Self {
        self.vfs = Some(vfs.to_owned());
        self
    }

    /// Sets how many connections to open; `0` is treated as `1`.
    ///
    /// When unset, anonymous in-memory pools use one connection and every other pool
    /// uses [`available_parallelism`], falling back to one if that cannot be determined.
    pub fn num_conns(mut self, num_conns: usize) -> Self {
        self.num_conns = Some(num_conns.max(1));
        self
    }

    /// Sets the queue capacity passed to each connection's [`ClientBuilder`].
    pub fn queue_capacity(mut self, queue_capacity: usize) -> Self {
        self.queue_capacity = Some(queue_capacity);
        self
    }

    /// Opens every connection asynchronously and returns the pool.
    ///
    /// The first connection is opened on its own; the remaining ones are then opened
    /// concurrently.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Config`] for an invalid configuration (empty shared-memory name,
    /// or more than one connection to an anonymous in-memory database). Otherwise returns
    /// the first error, in connection order, reported while opening a connection.
    pub async fn open(self) -> Result<Pool, Error> {
        let num_conns = self.get_num_conns();
        self.validate(num_conns)?;
        let mut clients = Vec::with_capacity(num_conns);
        // NOTE(review): presumably the first connection is opened alone so that one-time
        // setup (e.g. switching the journal mode) completes before the others contend
        // for the database — confirm.
        clients.push(self.client_builder().open().await?);
        let rest = join_all((1..num_conns).map(|_| self.client_builder().open())).await;
        for client in rest {
            clients.push(client?);
        }
        Ok(Self::pool_from_clients(clients))
    }

    /// Blocking counterpart of [`open`](Self::open): opens the connections one after another.
    ///
    /// # Errors
    ///
    /// Same as [`open`](Self::open); stops at the first connection that fails to open.
    pub fn open_blocking(self) -> Result<Pool, Error> {
        let num_conns = self.get_num_conns();
        self.validate(num_conns)?;
        let mut clients = Vec::with_capacity(num_conns);
        for _ in 0..num_conns {
            clients.push(self.client_builder().open_blocking()?);
        }
        Ok(Self::pool_from_clients(clients))
    }

    /// Wraps already-opened clients in a [`Pool`] whose round-robin counter starts at zero.
    fn pool_from_clients(clients: Vec<Client>) -> Pool {
        Pool {
            state: Arc::new(State {
                clients,
                counter: AtomicU64::new(0),
            }),
        }
    }

    /// Resolves the effective connection count; the result is always at least 1.
    fn get_num_conns(&self) -> usize {
        match self.num_conns {
            Some(num_conns) => num_conns,
            // Each connection to an anonymous in-memory database sees its own private
            // database, so a single connection is the only coherent default.
            None if self.is_anonymous_memory() => 1,
            None => available_parallelism().map_or(1, NonZeroUsize::get),
        }
    }

    /// Rejects configurations that cannot produce a coherent pool.
    ///
    /// # Errors
    ///
    /// [`Error::Config`] if the shared-memory name is empty, or if an anonymous
    /// in-memory database is requested with more than one connection.
    fn validate(&self, num_conns: usize) -> Result<(), Error> {
        if self.shared_memory_name.as_deref() == Some("") {
            return Err(Error::Config {
                message: "shared memory database name must not be empty",
            });
        }
        if self.is_anonymous_memory() && num_conns > 1 {
            return Err(Error::Config {
                message: "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools",
            });
        }
        Ok(())
    }

    /// Builds the per-connection [`ClientBuilder`] from this pool's settings.
    fn client_builder(&self) -> ClientBuilder {
        ClientBuilder {
            path: self.connection_path(),
            flags: self.connection_flags(),
            journal_mode: self.journal_mode,
            vfs: self.vfs.clone(),
            queue_capacity: self.queue_capacity,
        }
    }

    /// Path handed to each connection: the shared-memory URI if a name is set,
    /// otherwise the configured path (or `None`).
    fn connection_path(&self) -> Option<PathBuf> {
        self.shared_memory_name
            .as_deref()
            .map(shared_memory_uri)
            .or_else(|| self.path.clone())
    }

    /// Open flags handed to each connection.
    ///
    /// For shared-memory pools, SQLite only interprets the `file:` name as a URI when
    /// `SQLITE_OPEN_URI` is set. Shared cache is what lets the connections see one
    /// in-memory database, so `SQLITE_OPEN_PRIVATE_CACHE` is removed as contradictory.
    fn connection_flags(&self) -> OpenFlags {
        let mut flags = self.flags;
        if self.shared_memory_name.is_some() {
            flags.insert(OpenFlags::SQLITE_OPEN_URI);
            flags.insert(OpenFlags::SQLITE_OPEN_SHARED_CACHE);
            flags.remove(OpenFlags::SQLITE_OPEN_PRIVATE_CACHE);
        }
        flags
    }

    /// True when no shared-memory name is set and the path is unset or `":memory:"`.
    ///
    /// NOTE(review): SQLite also gives every connection its own private temporary
    /// database for an empty path (`""`); that case is not treated as anonymous here —
    /// confirm whether it should be.
    fn is_anonymous_memory(&self) -> bool {
        self.shared_memory_name.is_none()
            && self
                .path
                .as_deref()
                .is_none_or(|path| path == Path::new(":memory:"))
    }
}
/// Builds the SQLite URI for a named, shared-cache in-memory database.
///
/// The name is percent-encoded so characters such as `?`, `#`, `&` or `/` cannot be
/// misread as URI syntax; for example `"my db"` becomes
/// `file:my%20db?mode=memory&cache=shared`.
fn shared_memory_uri(name: &str) -> PathBuf {
    const PREFIX: &str = "file:";
    const QUERY: &str = "?mode=memory&cache=shared";
    let mut encoded = String::with_capacity(PREFIX.len() + name.len() + QUERY.len());
    encoded.push_str(PREFIX);
    push_uri_encoded(name, &mut encoded);
    encoded.push_str(QUERY);
    PathBuf::from(encoded)
}
/// Appends `input` to `out`, percent-encoding every byte outside the RFC 3986
/// "unreserved" set (`A-Z a-z 0-9 - . _ ~`) as `%XX` with uppercase hex digits.
///
/// Encoding works on UTF-8 bytes, so a multi-byte character becomes several escapes.
fn push_uri_encoded(input: &str, out: &mut String) {
    const DIGITS: [char; 16] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F',
    ];
    for b in input.as_bytes().iter().copied() {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~');
        if unreserved {
            out.push(char::from(b));
            continue;
        }
        out.push('%');
        out.push(DIGITS[usize::from(b >> 4)]);
        out.push(DIGITS[usize::from(b & 0x0F)]);
    }
}
/// A fixed set of SQLite clients that hands out work in round-robin order.
///
/// Built by `PoolBuilder::open` / `PoolBuilder::open_blocking`. Cloning is cheap:
/// every clone shares the same clients and counter through an `Arc`.
#[derive(Clone)]
pub struct Pool {
    state: Arc<State>,
}
/// State shared by every clone of a [`Pool`].
struct State {
    /// Opened clients; the builder always opens at least one.
    clients: Vec<Client>,
    /// Ticket counter incremented on every checkout to pick the next client.
    counter: AtomicU64,
}
impl Pool {
pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn(func).await
}
pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_mut(func).await
}
pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
where
F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
T: Send + 'static,
E: From<rusqlite::Error> + From<Error> + Send + 'static,
{
self.get().conn_and_then(func).await
}
pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
where
F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
T: Send + 'static,
E: From<rusqlite::Error> + From<Error> + Send + 'static,
{
self.get().conn_mut_and_then(func).await
}
pub async fn close(&self) -> Result<(), Error> {
let closes = self.state.clients.iter().map(|client| client.close());
let res = join_all(closes).await;
res.into_iter().collect::<Result<Vec<_>, Error>>()?;
Ok(())
}
pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_blocking(func)
}
pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_mut_blocking(func)
}
pub fn conn_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
where
F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
T: Send + 'static,
E: From<rusqlite::Error> + From<Error> + Send + 'static,
{
self.get().conn_and_then_blocking(func)
}
pub fn conn_mut_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
where
F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
T: Send + 'static,
E: From<rusqlite::Error> + From<Error> + Send + 'static,
{
self.get().conn_mut_and_then_blocking(func)
}
pub fn close_blocking(&self) -> Result<(), Error> {
let mut first_err = None;
for client in self.state.clients.iter() {
if let Err(e) = client.close_blocking() {
if first_err.is_none() {
first_err = Some(e);
}
}
}
match first_err {
Some(e) => Err(e),
None => Ok(()),
}
}
fn get(&self) -> &Client {
let n = self.state.counter.fetch_add(1, Relaxed);
&self.state.clients[n as usize % self.state.clients.len()]
}
pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
where
F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
T: Send + 'static,
{
let func = Arc::new(func);
let futures = self.state.clients.iter().map(|client| {
let func = func.clone();
async move { client.conn(move |conn| func(conn)).await }
});
join_all(futures).await
}
pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
where
F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
T: Send + 'static,
{
let func = Arc::new(func);
self.state
.clients
.iter()
.map(|client| {
let func = func.clone();
client.conn_blocking(move |conn| func(conn))
})
.collect()
}
}