1use std::{
2 path::{Path, PathBuf},
3 sync::{
4 Arc,
5 atomic::{AtomicU32, Ordering::Relaxed},
6 },
7 thread::available_parallelism,
8};
9
10use crate::{Client, ClientBuilder, Error};
11
12use duckdb::{Config, Connection};
13use futures_util::future::join_all;
14
15#[derive(Clone, Debug, Default)]
34pub struct PoolBuilder {
35 pub(crate) path: Option<PathBuf>,
36 pub(crate) flagsfn: Option<fn() -> duckdb::Result<Config>>,
37 pub(crate) num_conns: Option<usize>,
38}
39
40impl PoolBuilder {
41 #[must_use]
43 pub fn new() -> Self {
44 Self::default()
45 }
46
47 #[must_use]
51 pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
52 self.path = Some(path.as_ref().into());
53 if self.flagsfn.is_none() {
54 let cfg_fn = || Config::default().access_mode(duckdb::AccessMode::ReadOnly);
55 self.flagsfn = Some(cfg_fn);
56 }
57 self
58 }
59
60 #[must_use]
64 pub fn flagsfn(mut self, flags: fn() -> duckdb::Result<Config>) -> Self {
65 self.flagsfn = Some(flags);
66 self
67 }
68
69 #[must_use]
73 pub fn num_conns(mut self, num_conns: usize) -> Self {
74 self.num_conns = Some(num_conns);
75 self
76 }
77
78 pub async fn open(self) -> Result<Pool, Error> {
90 let num_conns = self.get_num_conns();
91 let opens = (0..num_conns).map(|_| {
92 ClientBuilder {
93 path: self.path.clone(),
94 flagsfn: self.flagsfn,
95 }
96 .open()
97 });
98 let clients = join_all(opens)
99 .await
100 .into_iter()
101 .collect::<Result<Vec<Client>, Error>>()?;
102 Ok(Pool {
103 state: Arc::new(State {
104 clients,
105 counter: AtomicU32::new(0),
106 }),
107 })
108 }
109
110 pub fn open_blocking(self) -> Result<Pool, Error> {
123 let num_conns = self.get_num_conns();
124 let clients = (0..num_conns)
125 .map(|_| {
126 ClientBuilder {
127 path: self.path.clone(),
128 flagsfn: self.flagsfn,
129 }
130 .open_blocking()
131 })
132 .collect::<Result<Vec<Client>, Error>>()?;
133 Ok(Pool {
134 state: Arc::new(State {
135 clients,
136 counter: AtomicU32::new(0),
137 }),
138 })
139 }
140
141 fn get_num_conns(&self) -> usize {
142 self.num_conns.unwrap_or_else(|| {
143 match available_parallelism() {
144 Ok(n) => n.get(),
145 Err(_) => 1,
146 }
147
148 })
154 }
155}
156
157#[derive(Clone)]
161pub struct Pool {
162 state: Arc<State>,
163}
164
165struct State {
166 clients: Vec<Client>,
167 counter: AtomicU32,
168}
169
170impl Pool {
171 pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
173 where
174 F: FnOnce(&Connection) -> Result<T, duckdb::Error> + Send + 'static,
175 T: Send + 'static,
176 {
177 self.get().conn(func).await
178 }
179
180 pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
182 where
183 F: FnOnce(&mut Connection) -> Result<T, duckdb::Error> + Send + 'static,
184 T: Send + 'static,
185 {
186 self.get().conn_mut(func).await
187 }
188
189 pub async fn close(&self) -> Result<(), Error> {
194 for client in &self.state.clients {
195 client.close().await?;
196 }
197 Ok(())
198 }
199
200 pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
203 where
204 F: FnOnce(&Connection) -> Result<T, duckdb::Error> + Send + 'static,
205 T: Send + 'static,
206 {
207 self.get().conn_blocking(func)
208 }
209
210 pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
213 where
214 F: FnOnce(&mut Connection) -> Result<T, duckdb::Error> + Send + 'static,
215 T: Send + 'static,
216 {
217 self.get().conn_mut_blocking(func)
218 }
219
220 pub fn close_blocking(&self) -> Result<(), Error> {
225 self.state
226 .clients
227 .iter()
228 .try_for_each(super::client::Client::close_blocking)
229 }
230
231 fn get(&self) -> &Client {
232 let n = self.state.counter.fetch_add(1, Relaxed);
233 &self.state.clients[n as usize % self.state.clients.len()]
234 }
235
236 pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
240 where
241 F: Fn(&Connection) -> Result<T, duckdb::Error> + Send + Sync + 'static,
242 T: Send + 'static,
243 {
244 let func = Arc::new(func);
245 let futures = self.state.clients.iter().map(|client| {
246 let func = func.clone();
247 async move { client.conn(move |conn| func(conn)).await }
248 });
249 join_all(futures).await
250 }
251
252 pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
254 where
255 F: Fn(&Connection) -> Result<T, duckdb::Error> + Send + Sync + 'static,
256 T: Send + 'static,
257 {
258 let func = Arc::new(func);
259 self.state
260 .clients
261 .iter()
262 .map(|client| {
263 let func = func.clone();
264 client.conn_blocking(move |conn| func(conn))
265 })
266 .collect()
267 }
268}