use std::{
path::{Path, PathBuf},
sync::{
Arc,
atomic::{AtomicU32, Ordering::Relaxed},
},
thread::available_parallelism,
};
use crate::{Client, ClientBuilder, Error};
use duckdb::{Config, Connection};
use futures_util::future::join_all;
/// Builder for a [`Pool`] of DuckDB connections.
///
/// Create one with [`PoolBuilder::new`], configure it with the setter methods, then call
/// `open` (async) or `open_blocking` to build the pool.
#[derive(Clone, Debug, Default)]
pub struct PoolBuilder {
    /// Database file path, handed unchanged to every per-connection [`ClientBuilder`].
    /// `None` when no path was set.
    pub(crate) path: Option<PathBuf>,
    /// Function producing the [`Config`] connections are opened with, handed to every
    /// [`ClientBuilder`]. `None` when no function was set.
    pub(crate) flagsfn: Option<fn() -> duckdb::Result<Config>>,
    /// Requested number of connections. `None` means "use the available parallelism".
    pub(crate) num_conns: Option<usize>,
}
impl PoolBuilder {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
self.path = Some(path.as_ref().into());
if self.flagsfn.is_none() {
let cfg_fn = || Config::default().access_mode(duckdb::AccessMode::ReadOnly);
self.flagsfn = Some(cfg_fn);
}
self
}
#[must_use]
pub fn flagsfn(mut self, flags: fn() -> duckdb::Result<Config>) -> Self {
self.flagsfn = Some(flags);
self
}
#[must_use]
pub fn num_conns(mut self, num_conns: usize) -> Self {
self.num_conns = Some(num_conns);
self
}
pub async fn open(self) -> Result<Pool, Error> {
let num_conns = self.get_num_conns();
let opens = (0..num_conns).map(|_| {
ClientBuilder {
path: self.path.clone(),
flagsfn: self.flagsfn,
}
.open()
});
let clients = join_all(opens)
.await
.into_iter()
.collect::<Result<Vec<Client>, Error>>()?;
Ok(Pool {
state: Arc::new(State {
clients,
counter: AtomicU32::new(0),
}),
})
}
pub fn open_blocking(self) -> Result<Pool, Error> {
let num_conns = self.get_num_conns();
let clients = (0..num_conns)
.map(|_| {
ClientBuilder {
path: self.path.clone(),
flagsfn: self.flagsfn,
}
.open_blocking()
})
.collect::<Result<Vec<Client>, Error>>()?;
Ok(Pool {
state: Arc::new(State {
clients,
counter: AtomicU32::new(0),
}),
})
}
fn get_num_conns(&self) -> usize {
self.num_conns.unwrap_or_else(|| {
match available_parallelism() {
Ok(n) => n.get(),
Err(_) => 1,
}
})
}
}
/// A fixed set of DuckDB [`Client`]s; requests are spread across them round-robin.
///
/// Cloning a `Pool` is cheap: every clone shares the same clients through one [`Arc`].
#[derive(Clone)]
pub struct Pool {
    state: Arc<State>,
}
/// State shared by every clone of a [`Pool`].
struct State {
    /// The pool's clients. The set is not modified after the pool is built.
    clients: Vec<Client>,
    /// Ever-increasing (wrapping) counter used to pick the next client round-robin.
    counter: AtomicU32,
}
impl Pool {
pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, duckdb::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn(func).await
}
pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, duckdb::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_mut(func).await
}
pub async fn close(&self) -> Result<(), Error> {
for client in &self.state.clients {
client.close().await?;
}
Ok(())
}
pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&Connection) -> Result<T, duckdb::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_blocking(func)
}
pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
where
F: FnOnce(&mut Connection) -> Result<T, duckdb::Error> + Send + 'static,
T: Send + 'static,
{
self.get().conn_mut_blocking(func)
}
pub fn close_blocking(&self) -> Result<(), Error> {
self.state
.clients
.iter()
.try_for_each(super::client::Client::close_blocking)
}
fn get(&self) -> &Client {
let n = self.state.counter.fetch_add(1, Relaxed);
&self.state.clients[n as usize % self.state.clients.len()]
}
pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
where
F: Fn(&Connection) -> Result<T, duckdb::Error> + Send + Sync + 'static,
T: Send + 'static,
{
let func = Arc::new(func);
let futures = self.state.clients.iter().map(|client| {
let func = func.clone();
async move { client.conn(move |conn| func(conn)).await }
});
join_all(futures).await
}
pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
where
F: Fn(&Connection) -> Result<T, duckdb::Error> + Send + Sync + 'static,
T: Send + 'static,
{
let func = Arc::new(func);
self.state
.clients
.iter()
.map(|client| {
let func = func.clone();
client.conn_blocking(move |conn| func(conn))
})
.collect()
}
}