Skip to main content

async_sqlite/
pool.rs

1use std::{
2    num::NonZeroUsize,
3    path::{Path, PathBuf},
4    sync::{
5        atomic::{AtomicU64, Ordering::Relaxed},
6        Arc,
7    },
8    thread::available_parallelism,
9};
10
11use crate::{Client, ClientBuilder, Error, JournalMode};
12
13use futures_util::future::join_all;
14use rusqlite::{Connection, OpenFlags};
15
16/// A `PoolBuilder` can be used to create a [`Pool`] with custom
17/// configuration.
18///
19/// See [`Client`] for more information.
20///
21/// # Examples
22///
23/// ```rust
24/// # use async_sqlite::PoolBuilder;
25/// # async fn run() -> Result<(), async_sqlite::Error> {
26/// let pool = PoolBuilder::new().path("path/to/db.sqlite3").open().await?;
27///
28/// // ...
29///
30/// pool.close().await?;
31/// # Ok(())
32/// # }
33/// ```
34#[derive(Clone, Debug, Default)]
35pub struct PoolBuilder {
36    path: Option<PathBuf>,
37    shared_memory_name: Option<String>,
38    flags: OpenFlags,
39    journal_mode: Option<JournalMode>,
40    vfs: Option<String>,
41    num_conns: Option<usize>,
42    queue_capacity: Option<usize>,
43}
44
45impl PoolBuilder {
46    /// Returns a new [`PoolBuilder`] with the default settings.
47    pub fn new() -> Self {
48        Self::default()
49    }
50
51    /// Specify the path of the sqlite3 database to open.
52    ///
53    /// By default, an in-memory database is used.
54    pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
55        self.path = Some(path.as_ref().into());
56        self.shared_memory_name = None;
57        self
58    }
59
60    /// Use a named shared in-memory sqlite database.
61    ///
62    /// This opens connections with a URI of the form
63    /// `file:<name>?mode=memory&cache=shared` and enables
64    /// [`OpenFlags::SQLITE_OPEN_URI`] and
65    /// [`OpenFlags::SQLITE_OPEN_SHARED_CACHE`].
66    ///
67    /// SQLite shared-cache mode has caveats and is discouraged by SQLite for
68    /// many workloads. Prefer a file-backed database when possible. The
69    /// in-memory database is deleted after the last connection using this name
70    /// is closed.
71    ///
72    /// ```
73    /// use async_sqlite::PoolBuilder;
74    ///
75    /// let builder = PoolBuilder::new().shared_memory("my-pool").num_conns(2);
76    /// ```
77    pub fn shared_memory<N: AsRef<str>>(mut self, name: N) -> Self {
78        self.path = None;
79        self.shared_memory_name = Some(name.as_ref().to_owned());
80        self
81    }
82
83    /// Specify the [`OpenFlags`] to use when opening a new connection.
84    ///
85    /// By default, [`OpenFlags::default()`] is used.
86    pub fn flags(mut self, flags: OpenFlags) -> Self {
87        self.flags = flags;
88        self
89    }
90
91    /// Specify the [`JournalMode`] to set when opening a new connection.
92    ///
93    /// By default, no `journal_mode` is explicity set.
94    pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
95        self.journal_mode = Some(journal_mode);
96        self
97    }
98
99    /// Specify the name of the [vfs](https://www.sqlite.org/vfs.html) to use.
100    pub fn vfs(mut self, vfs: &str) -> Self {
101        self.vfs = Some(vfs.to_owned());
102        self
103    }
104
105    /// Specify the number of sqlite connections to open as part of the pool.
106    ///
107    /// File-backed and shared-memory pools default to the number of logical
108    /// CPUs of the current system. Anonymous in-memory pools, including
109    /// `path(":memory:")`, default to `1` connection because each sqlite
110    /// `:memory:` connection is a separate database. Values less than `1` are
111    /// clamped to `1`.
112    ///
113    /// ```
114    /// use async_sqlite::PoolBuilder;
115    ///
116    /// let builder = PoolBuilder::new().num_conns(2);
117    /// ```
118    pub fn num_conns(mut self, num_conns: usize) -> Self {
119        self.num_conns = Some(num_conns.max(1));
120        self
121    }
122
123    /// Limit the number of commands that may wait in each connection's worker
124    /// queue.
125    ///
126    /// By default, each queue is unbounded. If a capacity is configured, calls
127    /// return [`Error::QueueFull`] when a selected connection already has that
128    /// many commands waiting for its worker thread. A capacity of `0` allows a
129    /// command to be accepted only when the worker is ready to receive it
130    /// immediately.
131    pub fn queue_capacity(mut self, queue_capacity: usize) -> Self {
132        self.queue_capacity = Some(queue_capacity);
133        self
134    }
135
136    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration.
137    ///
138    /// # Examples
139    ///
140    /// ```rust
141    /// # use async_sqlite::PoolBuilder;
142    /// # async fn run() -> Result<(), async_sqlite::Error> {
143    /// let pool = PoolBuilder::new().open().await?;
144    /// # Ok(())
145    /// # }
146    /// ```
147    pub async fn open(self) -> Result<Pool, Error> {
148        let num_conns = self.get_num_conns();
149        self.validate(num_conns)?;
150
151        // Open the first connection with full config (including journal_mode).
152        // This must complete before opening remaining connections to avoid
153        // concurrent PRAGMA writes on a new database file.
154        let first = self.client_builder().open().await?;
155
156        // Open remaining connections with journal_mode too, so connection-local
157        // modes are applied consistently across the pool.
158        let opens = (1..num_conns).map(|_| self.client_builder().open());
159        let mut clients = vec![first];
160        clients.extend(
161            join_all(opens)
162                .await
163                .into_iter()
164                .collect::<Result<Vec<Client>, Error>>()?,
165        );
166
167        Ok(Pool {
168            state: Arc::new(State {
169                clients,
170                counter: AtomicU64::new(0),
171            }),
172        })
173    }
174
175    /// Returns a new [`Pool`] that uses the `PoolBuilder` configuration,
176    /// blocking the current thread.
177    ///
178    /// # Examples
179    ///
180    /// ```rust
181    /// # use async_sqlite::PoolBuilder;
182    /// # fn run() -> Result<(), async_sqlite::Error> {
183    /// let pool = PoolBuilder::new().open_blocking()?;
184    /// # Ok(())
185    /// # }
186    /// ```
187    pub fn open_blocking(self) -> Result<Pool, Error> {
188        let num_conns = self.get_num_conns();
189        self.validate(num_conns)?;
190
191        // Open the first connection with full config (including journal_mode).
192        let first = self.client_builder().open_blocking()?;
193
194        // Open remaining connections with journal_mode too, so connection-local
195        // modes are applied consistently across the pool.
196        let mut clients = vec![first];
197        clients.extend(
198            (1..num_conns)
199                .map(|_| self.client_builder().open_blocking())
200                .collect::<Result<Vec<Client>, Error>>()?,
201        );
202
203        Ok(Pool {
204            state: Arc::new(State {
205                clients,
206                counter: AtomicU64::new(0),
207            }),
208        })
209    }
210
211    fn get_num_conns(&self) -> usize {
212        if let Some(num_conns) = self.num_conns {
213            return num_conns;
214        }
215
216        if self.is_anonymous_memory() {
217            return 1;
218        }
219
220        available_parallelism()
221            .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
222            .into()
223    }
224
225    fn validate(&self, num_conns: usize) -> Result<(), Error> {
226        if self
227            .shared_memory_name
228            .as_ref()
229            .is_some_and(|name| name.is_empty())
230        {
231            return Err(Error::Config {
232                message: "shared memory database name must not be empty",
233            });
234        }
235
236        if self.is_anonymous_memory() && num_conns > 1 {
237            return Err(Error::Config {
238                message: "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools",
239            });
240        }
241
242        Ok(())
243    }
244
245    fn client_builder(&self) -> ClientBuilder {
246        ClientBuilder {
247            path: self.connection_path(),
248            flags: self.connection_flags(),
249            journal_mode: self.journal_mode,
250            vfs: self.vfs.clone(),
251            queue_capacity: self.queue_capacity,
252        }
253    }
254
255    fn connection_path(&self) -> Option<PathBuf> {
256        self.shared_memory_name
257            .as_deref()
258            .map(shared_memory_uri)
259            .or_else(|| self.path.clone())
260    }
261
262    fn connection_flags(&self) -> OpenFlags {
263        let mut flags = self.flags;
264        if self.shared_memory_name.is_some() {
265            flags.insert(OpenFlags::SQLITE_OPEN_URI);
266            flags.insert(OpenFlags::SQLITE_OPEN_SHARED_CACHE);
267            flags.remove(OpenFlags::SQLITE_OPEN_PRIVATE_CACHE);
268        }
269        flags
270    }
271
272    fn is_anonymous_memory(&self) -> bool {
273        self.shared_memory_name.is_none()
274            && self
275                .path
276                .as_deref()
277                .is_none_or(|path| path == Path::new(":memory:"))
278    }
279}
280
281fn shared_memory_uri(name: &str) -> PathBuf {
282    let mut uri = String::from("file:");
283    push_uri_encoded(name, &mut uri);
284    uri.push_str("?mode=memory&cache=shared");
285    uri.into()
286}
287
/// Appends `input` to `out`, leaving the unreserved URI characters
/// (`A-Z a-z 0-9 - . _ ~`) as-is. Every other byte is written as `%XX` with
/// uppercase hex digits.
///
/// Encoding operates on UTF-8 bytes, so a multi-byte character becomes several
/// `%XX` escapes.
fn push_uri_encoded(input: &str, out: &mut String) {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    input.bytes().for_each(|b| {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~');
        if unreserved {
            out.push(char::from(b));
        } else {
            let high = HEX_DIGITS[usize::from(b >> 4)];
            let low = HEX_DIGITS[usize::from(b & 0x0F)];
            out.extend(['%', char::from(high), char::from(low)]);
        }
    });
}
304
305/// A simple Pool of sqlite connections.
306///
307/// A Pool has the same API as an individual [`Client`].
308#[derive(Clone)]
309pub struct Pool {
310    state: Arc<State>,
311}
312
313struct State {
314    clients: Vec<Client>,
315    counter: AtomicU64,
316}
317
318impl Pool {
319    /// Invokes the provided function with a [`rusqlite::Connection`].
320    pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
321    where
322        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
323        T: Send + 'static,
324    {
325        self.get().conn(func).await
326    }
327
328    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
329    pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
330    where
331        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
332        T: Send + 'static,
333    {
334        self.get().conn_mut(func).await
335    }
336
337    /// Invokes the provided function with a [`rusqlite::Connection`].
338    ///
339    /// Maps the result error type to a custom error; designed to be
340    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
341    pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
342    where
343        F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
344        T: Send + 'static,
345        E: From<rusqlite::Error> + From<Error> + Send + 'static,
346    {
347        self.get().conn_and_then(func).await
348    }
349
350    /// Invokes the provided function with a mutable [`rusqlite::Connection`].
351    ///
352    /// Maps the result error type to a custom error; designed to be
353    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
354    pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
355    where
356        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
357        T: Send + 'static,
358        E: From<rusqlite::Error> + From<Error> + Send + 'static,
359    {
360        self.get().conn_mut_and_then(func).await
361    }
362
363    /// Closes the underlying sqlite connections.
364    ///
365    /// After this method returns, all calls to `self::conn()` or
366    /// `self::conn_mut()` will return an [`Error::Closed`] error.
367    pub async fn close(&self) -> Result<(), Error> {
368        let closes = self.state.clients.iter().map(|client| client.close());
369        let res = join_all(closes).await;
370        res.into_iter().collect::<Result<Vec<_>, Error>>()?;
371        Ok(())
372    }
373
374    /// Invokes the provided function with a [`rusqlite::Connection`], blocking
375    /// the current thread.
376    pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
377    where
378        F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
379        T: Send + 'static,
380    {
381        self.get().conn_blocking(func)
382    }
383
384    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
385    /// blocking the current thread.
386    pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
387    where
388        F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
389        T: Send + 'static,
390    {
391        self.get().conn_mut_blocking(func)
392    }
393
394    /// Invokes the provided function with a [`rusqlite::Connection`], blocking
395    /// the current thread.
396    ///
397    /// Maps the result error type to a custom error; designed to be
398    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
399    pub fn conn_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
400    where
401        F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
402        T: Send + 'static,
403        E: From<rusqlite::Error> + From<Error> + Send + 'static,
404    {
405        self.get().conn_and_then_blocking(func)
406    }
407
408    /// Invokes the provided function with a mutable [`rusqlite::Connection`],
409    /// blocking the current thread.
410    ///
411    /// Maps the result error type to a custom error; designed to be
412    /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then).
413    pub fn conn_mut_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
414    where
415        F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
416        T: Send + 'static,
417        E: From<rusqlite::Error> + From<Error> + Send + 'static,
418    {
419        self.get().conn_mut_and_then_blocking(func)
420    }
421
422    /// Closes the underlying sqlite connections, blocking the current thread.
423    ///
424    /// After this method returns, all calls to `self::conn_blocking()` or
425    /// `self::conn_mut_blocking()` will return an [`Error::Closed`] error.
426    pub fn close_blocking(&self) -> Result<(), Error> {
427        let mut first_err = None;
428        for client in self.state.clients.iter() {
429            if let Err(e) = client.close_blocking() {
430                if first_err.is_none() {
431                    first_err = Some(e);
432                }
433            }
434        }
435        match first_err {
436            Some(e) => Err(e),
437            None => Ok(()),
438        }
439    }
440
441    fn get(&self) -> &Client {
442        let n = self.state.counter.fetch_add(1, Relaxed);
443        &self.state.clients[n as usize % self.state.clients.len()]
444    }
445
446    /// Runs a function on all connections in the pool asynchronously.
447    ///
448    /// The function is executed on each connection concurrently.
449    pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
450    where
451        F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
452        T: Send + 'static,
453    {
454        let func = Arc::new(func);
455        let futures = self.state.clients.iter().map(|client| {
456            let func = func.clone();
457            async move { client.conn(move |conn| func(conn)).await }
458        });
459        join_all(futures).await
460    }
461
462    /// Runs a function on all connections in the pool, blocking the current thread.
463    pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
464    where
465        F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
466        T: Send + 'static,
467    {
468        let func = Arc::new(func);
469        self.state
470            .clients
471            .iter()
472            .map(|client| {
473                let func = func.clone();
474                client.conn_blocking(move |conn| func(conn))
475            })
476            .collect()
477    }
478}