async_duckdb/
pool.rs

1use std::{
2    path::{Path, PathBuf},
3    sync::{
4        Arc,
5        atomic::{AtomicU32, Ordering::Relaxed},
6    },
7    thread::available_parallelism,
8};
9
10use crate::{Client, ClientBuilder, Error};
11
12use duckdb::{Config, Connection};
13use futures_util::future::join_all;
14
15/// A `PoolBuilder` can be used to create a [`Pool`] with custom
16/// configuration.
17///
18/// See [`Client`] for more information.
19///
20/// # Examples
21///
22/// ```rust
23/// # use async_duckdb::PoolBuilder;
24/// # async fn run() -> Result<(), async_duckdb::Error> {
25/// let pool = PoolBuilder::new().path("path/to/db.duck").open().await?;
26///
27/// // ...
28///
29/// pool.close().await?;
30/// # Ok(())
31/// # }
32/// ```
33#[derive(Clone, Debug, Default)]
34pub struct PoolBuilder {
35    pub(crate) path: Option<PathBuf>,
36    pub(crate) flagsfn: Option<fn() -> duckdb::Result<Config>>,
37    pub(crate) num_conns: Option<usize>,
38}
39
40impl PoolBuilder {
41    /// Returns a new [`PoolBuilder`] with the default settings.
42    #[must_use]
43    pub fn new() -> Self {
44        Self::default()
45    }
46
47    /// Specify the path of the duckdb database to open.
48    ///
49    /// By default, an in-memory database is used.
50    #[must_use]
51    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
52        self.path = Some(path.as_ref().into());
53        if self.flagsfn.is_none() {
54            let cfg_fn = || Config::default().access_mode(duckdb::AccessMode::ReadOnly);
55            self.flagsfn = Some(cfg_fn);
56        }
57        self
58    }
59
60    /// Specify the [`OpenFlags`] to use when opening a new connection.
61    ///
62    /// By default, [`OpenFlags::default()`] is used.
63    #[must_use]
64    pub fn flagsfn(mut self, flags: fn() -> duckdb::Result<Config>) -> Self {
65        self.flagsfn = Some(flags);
66        self
67    }
68
69    /// Specify the number of duckdb connections to open as part of the pool.
70    ///
71    /// Defaults to the number of logical CPUs of the current system.
72    #[must_use]
73    pub fn num_conns(mut self, num_conns: usize) -> Self {
74        self.num_conns = Some(num_conns);
75        self
76    }
77
78    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration.
79    ///
80    /// # Examples
81    ///
82    /// ```rust
83    /// # use async_duckdb::PoolBuilder;
84    /// # async fn run() -> Result<(), async_duckdb::Error> {
85    /// let pool = PoolBuilder::new().open().await?;
86    /// # Ok(())
87    /// # }
88    /// ```
89    pub async fn open(self) -> Result<Pool, Error> {
90        let num_conns = self.get_num_conns();
91        let opens = (0..num_conns).map(|_| {
92            ClientBuilder {
93                path: self.path.clone(),
94                flagsfn: self.flagsfn,
95            }
96            .open()
97        });
98        let clients = join_all(opens)
99            .await
100            .into_iter()
101            .collect::<Result<Vec<Client>, Error>>()?;
102        Ok(Pool {
103            state: Arc::new(State {
104                clients,
105                counter: AtomicU32::new(0),
106            }),
107        })
108    }
109
110    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration,
111    /// blocking the current thread.
112    ///
113    /// # Examples
114    ///
115    /// ```rust
116    /// # use async_duckdb::PoolBuilder;
117    /// # fn run() -> Result<(), async_duckdb::Error> {
118    /// let pool = PoolBuilder::new().open_blocking()?;
119    /// # Ok(())
120    /// # }
121    /// ```
122    pub fn open_blocking(self) -> Result<Pool, Error> {
123        let num_conns = self.get_num_conns();
124        let clients = (0..num_conns)
125            .map(|_| {
126                ClientBuilder {
127                    path: self.path.clone(),
128                    flagsfn: self.flagsfn,
129                }
130                .open_blocking()
131            })
132            .collect::<Result<Vec<Client>, Error>>()?;
133        Ok(Pool {
134            state: Arc::new(State {
135                clients,
136                counter: AtomicU32::new(0),
137            }),
138        })
139    }
140
141    fn get_num_conns(&self) -> usize {
142        self.num_conns.unwrap_or_else(|| {
143            match available_parallelism() {
144                Ok(n) => n.get(),
145                Err(_) => 1,
146            }
147
148            // if let Err(e)  = available_parallelism() {
149            //     1
150            // available_parallelism()
151            //     .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
152            //     .into()
153        })
154    }
155}
156
157/// A simple Pool of duckdb connections.
158///
159/// A Pool has the same API as an individual [`Client`].
160#[derive(Clone)]
161pub struct Pool {
162    state: Arc<State>,
163}
164
165struct State {
166    clients: Vec<Client>,
167    counter: AtomicU32,
168}
169
170impl Pool {
171    /// Invokes the provided function with a [`duckdb::Connection`].
172    pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
173    where
174        F: FnOnce(&Connection) -> Result<T, duckdb::Error> + Send + 'static,
175        T: Send + 'static,
176    {
177        self.get().conn(func).await
178    }
179
180    /// Invokes the provided function with a mutable [`duckdb::Connection`].
181    pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
182    where
183        F: FnOnce(&mut Connection) -> Result<T, duckdb::Error> + Send + 'static,
184        T: Send + 'static,
185    {
186        self.get().conn_mut(func).await
187    }
188
189    /// Closes the underlying duckdb connections.
190    ///
191    /// After this method returns, all calls to `self::conn()` or
192    /// `self::conn_mut()` will return an [`Error::Closed`] error.
193    pub async fn close(&self) -> Result<(), Error> {
194        for client in &self.state.clients {
195            client.close().await?;
196        }
197        Ok(())
198    }
199
200    /// Invokes the provided function with a [`duckdb::Connection`], blocking
201    /// the current thread.
202    pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
203    where
204        F: FnOnce(&Connection) -> Result<T, duckdb::Error> + Send + 'static,
205        T: Send + 'static,
206    {
207        self.get().conn_blocking(func)
208    }
209
210    /// Invokes the provided function with a mutable [`duckdb::Connection`],
211    /// blocking the current thread.
212    pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
213    where
214        F: FnOnce(&mut Connection) -> Result<T, duckdb::Error> + Send + 'static,
215        T: Send + 'static,
216    {
217        self.get().conn_mut_blocking(func)
218    }
219
220    /// Closes the underlying duckdb connections, blocking the current thread.
221    ///
222    /// After this method returns, all calls to `self::conn_blocking()` or
223    /// `self::conn_mut_blocking()` will return an [`Error::Closed`] error.
224    pub fn close_blocking(&self) -> Result<(), Error> {
225        self.state
226            .clients
227            .iter()
228            .try_for_each(super::client::Client::close_blocking)
229    }
230
231    fn get(&self) -> &Client {
232        let n = self.state.counter.fetch_add(1, Relaxed);
233        &self.state.clients[n as usize % self.state.clients.len()]
234    }
235
236    /// Runs a function on all connections in the pool asynchronously.
237    ///
238    /// The function is executed on each connection concurrently.
239    pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
240    where
241        F: Fn(&Connection) -> Result<T, duckdb::Error> + Send + Sync + 'static,
242        T: Send + 'static,
243    {
244        let func = Arc::new(func);
245        let futures = self.state.clients.iter().map(|client| {
246            let func = func.clone();
247            async move { client.conn(move |conn| func(conn)).await }
248        });
249        join_all(futures).await
250    }
251
252    /// Runs a function on all connections in the pool, blocking the current thread.
253    pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
254    where
255        F: Fn(&Connection) -> Result<T, duckdb::Error> + Send + Sync + 'static,
256        T: Send + 'static,
257    {
258        let func = Arc::new(func);
259        self.state
260            .clients
261            .iter()
262            .map(|client| {
263                let func = func.clone();
264                client.conn_blocking(move |conn| func(conn))
265            })
266            .collect()
267    }
268}