Skip to main content

async_sqlite/
pool.rs

1use std::{
2    num::NonZeroUsize,
3    path::{Path, PathBuf},
4    sync::{
5        atomic::{AtomicU64, Ordering::Relaxed},
6        Arc,
7    },
8    thread::available_parallelism,
9};
10
11use crate::{Client, ClientBuilder, Error, JournalMode};
12
13use futures_util::future::join_all;
14use rusqlite::{Connection, OpenFlags};
15
/// A `PoolBuilder` can be used to create a [`Pool`] with custom
/// configuration.
///
/// See [`Client`] for more information.
///
/// # Examples
///
/// ```rust
/// # use async_sqlite::PoolBuilder;
/// # async fn run() -> Result<(), async_sqlite::Error> {
/// let pool = PoolBuilder::new().path("path/to/db.sqlite3").open().await?;
///
/// // ...
///
/// pool.close().await?;
/// # Ok(())
/// # }
/// ```
#[derive(Clone, Debug, Default)]
pub struct PoolBuilder {
    /// Database file to open; `None` selects an in-memory database (see
    /// [`PoolBuilder::path`]).
    path: Option<PathBuf>,
    /// Open flags used for the pool's connections (see [`PoolBuilder::flags`]).
    flags: OpenFlags,
    /// Journal mode requested via [`PoolBuilder::journal_mode`]; `None` means
    /// no journal mode is set explicitly.
    journal_mode: Option<JournalMode>,
    /// Name of the sqlite VFS to use, if one was requested via
    /// [`PoolBuilder::vfs`].
    vfs: Option<String>,
    /// Explicit connection count (already clamped to at least `1` by
    /// [`PoolBuilder::num_conns`]); `None` falls back to the number of logical
    /// CPUs when the pool is opened.
    num_conns: Option<usize>,
}
42
43impl PoolBuilder {
44    /// Returns a new [`PoolBuilder`] with the default settings.
45    pub fn new() -> Self {
46        Self::default()
47    }
48
49    /// Specify the path of the sqlite3 database to open.
50    ///
51    /// By default, an in-memory database is used.
52    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
53        self.path = Some(path.as_ref().into());
54        self
55    }
56
57    /// Specify the [`OpenFlags`] to use when opening a new connection.
58    ///
59    /// By default, [`OpenFlags::default()`] is used.
60    pub fn flags(mut self, flags: OpenFlags) -> Self {
61        self.flags = flags;
62        self
63    }
64
65    /// Specify the [`JournalMode`] to set when opening a new connection.
66    ///
67    /// By default, no `journal_mode` is explicity set.
68    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
69        self.journal_mode = Some(journal_mode);
70        self
71    }
72
73    /// Specify the name of the [vfs](https://www.sqlite.org/vfs.html) to use.
74    pub fn vfs(mut self, vfs: &str) -> Self {
75        self.vfs = Some(vfs.to_owned());
76        self
77    }
78
79    /// Specify the number of sqlite connections to open as part of the pool.
80    ///
81    /// Defaults to the number of logical CPUs of the current system. Values
82    /// less than `1` are clamped to `1`.
83    ///
84    /// ```
85    /// use async_sqlite::PoolBuilder;
86    ///
87    /// let builder = PoolBuilder::new().num_conns(2);
88    /// ```
89    pub fn num_conns(mut self, num_conns: usize) -> Self {
90        self.num_conns = Some(num_conns.max(1));
91        self
92    }
93
94    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration.
95    ///
96    /// # Examples
97    ///
98    /// ```rust
99    /// # use async_sqlite::PoolBuilder;
100    /// # async fn run() -> Result<(), async_sqlite::Error> {
101    /// let pool = PoolBuilder::new().open().await?;
102    /// # Ok(())
103    /// # }
104    /// ```
105    pub async fn open(self) -> Result<Pool, Error> {
106        let num_conns = self.get_num_conns();
107
108        // Open the first connection with full config (including journal_mode).
109        // This must complete before opening remaining connections to avoid
110        // concurrent PRAGMA writes on a new database file.
111        let first = ClientBuilder {
112            path: self.path.clone(),
113            flags: self.flags,
114            journal_mode: self.journal_mode,
115            vfs: self.vfs.clone(),
116        }
117        .open()
118        .await?;
119
120        // Open remaining connections without journal_mode since it's a
121        // database-level setting already applied by the first connection.
122        let opens = (1..num_conns).map(|_| {
123            ClientBuilder {
124                path: self.path.clone(),
125                flags: self.flags,
126                journal_mode: None,
127                vfs: self.vfs.clone(),
128            }
129            .open()
130        });
131        let mut clients = vec![first];
132        clients.extend(
133            join_all(opens)
134                .await
135                .into_iter()
136                .collect::<Result<Vec<Client>, Error>>()?,
137        );
138
139        Ok(Pool {
140            state: Arc::new(State {
141                clients,
142                counter: AtomicU64::new(0),
143            }),
144        })
145    }
146
147    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration,
148    /// blocking the current thread.
149    ///
150    /// # Examples
151    ///
152    /// ```rust
153    /// # use async_sqlite::PoolBuilder;
154    /// # fn run() -> Result<(), async_sqlite::Error> {
155    /// let pool = PoolBuilder::new().open_blocking()?;
156    /// # Ok(())
157    /// # }
158    /// ```
159    pub fn open_blocking(self) -> Result<Pool, Error> {
160        let num_conns = self.get_num_conns();
161
162        // Open the first connection with full config (including journal_mode).
163        let first = ClientBuilder {
164            path: self.path.clone(),
165            flags: self.flags,
166            journal_mode: self.journal_mode,
167            vfs: self.vfs.clone(),
168        }
169        .open_blocking()?;
170
171        // Open remaining connections without journal_mode since it's a
172        // database-level setting already applied by the first connection.
173        let mut clients = vec![first];
174        clients.extend(
175            (1..num_conns)
176                .map(|_| {
177                    ClientBuilder {
178                        path: self.path.clone(),
179                        flags: self.flags,
180                        journal_mode: None,
181                        vfs: self.vfs.clone(),
182                    }
183                    .open_blocking()
184                })
185                .collect::<Result<Vec<Client>, Error>>()?,
186        );
187
188        Ok(Pool {
189            state: Arc::new(State {
190                clients,
191                counter: AtomicU64::new(0),
192            }),
193        })
194    }
195
196    fn get_num_conns(&self) -> usize {
197        self.num_conns.unwrap_or_else(|| {
198            available_parallelism()
199                .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
200                .into()
201        })
202    }
203}
204
/// A simple Pool of sqlite connections.
///
/// A Pool has the same API as an individual [`Client`].
///
/// Cloning a `Pool` is cheap. All clones share the same connections through
/// an [`Arc`].
#[derive(Clone)]
pub struct Pool {
    /// Shared state holding the opened clients and the counter used to pick
    /// one of them.
    state: Arc<State>,
}
212
/// State shared by every clone of a [`Pool`].
struct State {
    /// The pool's connections. When built by [`PoolBuilder`], this always
    /// contains at least one client.
    clients: Vec<Client>,
    /// Incremented on every dispatched call; used to pick the next client in
    /// round-robin order.
    counter: AtomicU64,
}
217
218impl Pool {
219    /// Invokes the provided function with a [`rusqlite::Connection`].
220    pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
221    where
222        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
223        T: Send + 'static,
224    {
225        self.get().conn(func).await
226    }
227
228    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
229    pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
230    where
231        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
232        T: Send + 'static,
233    {
234        self.get().conn_mut(func).await
235    }
236
237    /// Closes the underlying sqlite connections.
238    ///
239    /// After this method returns, all calls to `self::conn()` or
240    /// `self::conn_mut()` will return an [`Error::Closed`] error.
241    pub async fn close(&self) -> Result<(), Error> {
242        let closes = self.state.clients.iter().map(|client| client.close());
243        let res = join_all(closes).await;
244        res.into_iter().collect::<Result<Vec<_>, Error>>()?;
245        Ok(())
246    }
247
248    /// Invokes the provided function with a [`rusqlite::Connection`], blocking
249    /// the current thread.
250    pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
251    where
252        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
253        T: Send + 'static,
254    {
255        self.get().conn_blocking(func)
256    }
257
258    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
259    /// blocking the current thread.
260    pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
261    where
262        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
263        T: Send + 'static,
264    {
265        self.get().conn_mut_blocking(func)
266    }
267
268    /// Closes the underlying sqlite connections, blocking the current thread.
269    ///
270    /// After this method returns, all calls to `self::conn_blocking()` or
271    /// `self::conn_mut_blocking()` will return an [`Error::Closed`] error.
272    pub fn close_blocking(&self) -> Result<(), Error> {
273        let mut first_err = None;
274        for client in self.state.clients.iter() {
275            if let Err(e) = client.close_blocking() {
276                if first_err.is_none() {
277                    first_err = Some(e);
278                }
279            }
280        }
281        match first_err {
282            Some(e) => Err(e),
283            None => Ok(()),
284        }
285    }
286
287    fn get(&self) -> &Client {
288        let n = self.state.counter.fetch_add(1, Relaxed);
289        &self.state.clients[n as usize % self.state.clients.len()]
290    }
291
292    /// Runs a function on all connections in the pool asynchronously.
293    ///
294    /// The function is executed on each connection concurrently.
295    pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
296    where
297        F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
298        T: Send + 'static,
299    {
300        let func = Arc::new(func);
301        let futures = self.state.clients.iter().map(|client| {
302            let func = func.clone();
303            async move { client.conn(move |conn| func(conn)).await }
304        });
305        join_all(futures).await
306    }
307
308    /// Runs a function on all connections in the pool, blocking the current thread.
309    pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
310    where
311        F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
312        T: Send + 'static,
313    {
314        let func = Arc::new(func);
315        self.state
316            .clients
317            .iter()
318            .map(|client| {
319                let func = func.clone();
320                client.conn_blocking(move |conn| func(conn))
321            })
322            .collect()
323    }
324}