1use std::{
2 num::NonZeroUsize,
3 path::{Path, PathBuf},
4 sync::{
5 atomic::{AtomicU64, Ordering::Relaxed},
6 Arc,
7 },
8 thread::available_parallelism,
9};
10
11use crate::{Client, ClientBuilder, Error, JournalMode};
12
13use futures_util::future::join_all;
14use rusqlite::{Connection, OpenFlags};
15
16#[derive(Clone, Debug, Default)]
35pub struct PoolBuilder {
36 path: Option<PathBuf>,
37 shared_memory_name: Option<String>,
38 flags: OpenFlags,
39 journal_mode: Option<JournalMode>,
40 vfs: Option<String>,
41 num_conns: Option<usize>,
42 queue_capacity: Option<usize>,
43}
44
45impl PoolBuilder {
46 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn path<P: AsRef<Path>>(mut self, path: P) -> Self {
55 self.path = Some(path.as_ref().into());
56 self.shared_memory_name = None;
57 self
58 }
59
60 pub fn shared_memory<N: AsRef<str>>(mut self, name: N) -> Self {
78 self.path = None;
79 self.shared_memory_name = Some(name.as_ref().to_owned());
80 self
81 }
82
83 pub fn flags(mut self, flags: OpenFlags) -> Self {
87 self.flags = flags;
88 self
89 }
90
91 pub fn journal_mode(mut self, journal_mode: JournalMode) -> Self {
95 self.journal_mode = Some(journal_mode);
96 self
97 }
98
99 pub fn vfs(mut self, vfs: &str) -> Self {
101 self.vfs = Some(vfs.to_owned());
102 self
103 }
104
105 pub fn num_conns(mut self, num_conns: usize) -> Self {
119 self.num_conns = Some(num_conns.max(1));
120 self
121 }
122
123 pub fn queue_capacity(mut self, queue_capacity: usize) -> Self {
132 self.queue_capacity = Some(queue_capacity);
133 self
134 }
135
136 pub async fn open(self) -> Result<Pool, Error> {
148 let num_conns = self.get_num_conns();
149 self.validate(num_conns)?;
150
151 let first = self.client_builder().open().await?;
155
156 let opens = (1..num_conns).map(|_| self.client_builder().open());
159 let mut clients = vec![first];
160 clients.extend(
161 join_all(opens)
162 .await
163 .into_iter()
164 .collect::<Result<Vec<Client>, Error>>()?,
165 );
166
167 Ok(Pool {
168 state: Arc::new(State {
169 clients,
170 counter: AtomicU64::new(0),
171 }),
172 })
173 }
174
175 pub fn open_blocking(self) -> Result<Pool, Error> {
188 let num_conns = self.get_num_conns();
189 self.validate(num_conns)?;
190
191 let first = self.client_builder().open_blocking()?;
193
194 let mut clients = vec![first];
197 clients.extend(
198 (1..num_conns)
199 .map(|_| self.client_builder().open_blocking())
200 .collect::<Result<Vec<Client>, Error>>()?,
201 );
202
203 Ok(Pool {
204 state: Arc::new(State {
205 clients,
206 counter: AtomicU64::new(0),
207 }),
208 })
209 }
210
211 fn get_num_conns(&self) -> usize {
212 if let Some(num_conns) = self.num_conns {
213 return num_conns;
214 }
215
216 if self.is_anonymous_memory() {
217 return 1;
218 }
219
220 available_parallelism()
221 .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap())
222 .into()
223 }
224
225 fn validate(&self, num_conns: usize) -> Result<(), Error> {
226 if self
227 .shared_memory_name
228 .as_ref()
229 .is_some_and(|name| name.is_empty())
230 {
231 return Err(Error::Config {
232 message: "shared memory database name must not be empty",
233 });
234 }
235
236 if self.is_anonymous_memory() && num_conns > 1 {
237 return Err(Error::Config {
238 message: "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools",
239 });
240 }
241
242 Ok(())
243 }
244
245 fn client_builder(&self) -> ClientBuilder {
246 ClientBuilder {
247 path: self.connection_path(),
248 flags: self.connection_flags(),
249 journal_mode: self.journal_mode,
250 vfs: self.vfs.clone(),
251 queue_capacity: self.queue_capacity,
252 }
253 }
254
255 fn connection_path(&self) -> Option<PathBuf> {
256 self.shared_memory_name
257 .as_deref()
258 .map(shared_memory_uri)
259 .or_else(|| self.path.clone())
260 }
261
262 fn connection_flags(&self) -> OpenFlags {
263 let mut flags = self.flags;
264 if self.shared_memory_name.is_some() {
265 flags.insert(OpenFlags::SQLITE_OPEN_URI);
266 flags.insert(OpenFlags::SQLITE_OPEN_SHARED_CACHE);
267 flags.remove(OpenFlags::SQLITE_OPEN_PRIVATE_CACHE);
268 }
269 flags
270 }
271
272 fn is_anonymous_memory(&self) -> bool {
273 self.shared_memory_name.is_none()
274 && self
275 .path
276 .as_deref()
277 .is_none_or(|path| path == Path::new(":memory:"))
278 }
279}
280
281fn shared_memory_uri(name: &str) -> PathBuf {
282 let mut uri = String::from("file:");
283 push_uri_encoded(name, &mut uri);
284 uri.push_str("?mode=memory&cache=shared");
285 uri.into()
286}
287
/// Appends `input` to `out`, percent-encoding every byte outside the RFC 3986
/// unreserved set (`A-Z a-z 0-9 - . _ ~`) as `%XX` with uppercase hex digits.
/// Multi-byte UTF-8 characters are encoded byte by byte.
fn push_uri_encoded(input: &str, out: &mut String) {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    input.bytes().for_each(|b| {
        let unreserved = b.is_ascii_alphanumeric() || matches!(b, b'-' | b'.' | b'_' | b'~');
        if unreserved {
            out.push(char::from(b));
        } else {
            out.push('%');
            out.push(char::from(DIGITS[usize::from(b >> 4)]));
            out.push(char::from(DIGITS[usize::from(b & 0x0F)]));
        }
    });
}
304
305#[derive(Clone)]
309pub struct Pool {
310 state: Arc<State>,
311}
312
313struct State {
314 clients: Vec<Client>,
315 counter: AtomicU64,
316}
317
318impl Pool {
319 pub async fn conn<F, T>(&self, func: F) -> Result<T, Error>
321 where
322 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
323 T: Send + 'static,
324 {
325 self.get().conn(func).await
326 }
327
328 pub async fn conn_mut<F, T>(&self, func: F) -> Result<T, Error>
330 where
331 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
332 T: Send + 'static,
333 {
334 self.get().conn_mut(func).await
335 }
336
337 pub async fn conn_and_then<F, T, E>(&self, func: F) -> Result<T, E>
342 where
343 F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
344 T: Send + 'static,
345 E: From<rusqlite::Error> + From<Error> + Send + 'static,
346 {
347 self.get().conn_and_then(func).await
348 }
349
350 pub async fn conn_mut_and_then<F, T, E>(&self, func: F) -> Result<T, E>
355 where
356 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
357 T: Send + 'static,
358 E: From<rusqlite::Error> + From<Error> + Send + 'static,
359 {
360 self.get().conn_mut_and_then(func).await
361 }
362
363 pub async fn close(&self) -> Result<(), Error> {
368 let closes = self.state.clients.iter().map(|client| client.close());
369 let res = join_all(closes).await;
370 res.into_iter().collect::<Result<Vec<_>, Error>>()?;
371 Ok(())
372 }
373
374 pub fn conn_blocking<F, T>(&self, func: F) -> Result<T, Error>
377 where
378 F: FnOnce(&Connection) -> Result<T, rusqlite::Error> + Send + 'static,
379 T: Send + 'static,
380 {
381 self.get().conn_blocking(func)
382 }
383
384 pub fn conn_mut_blocking<F, T>(&self, func: F) -> Result<T, Error>
387 where
388 F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error> + Send + 'static,
389 T: Send + 'static,
390 {
391 self.get().conn_mut_blocking(func)
392 }
393
394 pub fn conn_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
400 where
401 F: FnOnce(&Connection) -> Result<T, E> + Send + 'static,
402 T: Send + 'static,
403 E: From<rusqlite::Error> + From<Error> + Send + 'static,
404 {
405 self.get().conn_and_then_blocking(func)
406 }
407
408 pub fn conn_mut_and_then_blocking<F, T, E>(&self, func: F) -> Result<T, E>
414 where
415 F: FnOnce(&mut Connection) -> Result<T, E> + Send + 'static,
416 T: Send + 'static,
417 E: From<rusqlite::Error> + From<Error> + Send + 'static,
418 {
419 self.get().conn_mut_and_then_blocking(func)
420 }
421
422 pub fn close_blocking(&self) -> Result<(), Error> {
427 let mut first_err = None;
428 for client in self.state.clients.iter() {
429 if let Err(e) = client.close_blocking() {
430 if first_err.is_none() {
431 first_err = Some(e);
432 }
433 }
434 }
435 match first_err {
436 Some(e) => Err(e),
437 None => Ok(()),
438 }
439 }
440
441 fn get(&self) -> &Client {
442 let n = self.state.counter.fetch_add(1, Relaxed);
443 &self.state.clients[n as usize % self.state.clients.len()]
444 }
445
446 pub async fn conn_for_each<F, T>(&self, func: F) -> Vec<Result<T, Error>>
450 where
451 F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
452 T: Send + 'static,
453 {
454 let func = Arc::new(func);
455 let futures = self.state.clients.iter().map(|client| {
456 let func = func.clone();
457 async move { client.conn(move |conn| func(conn)).await }
458 });
459 join_all(futures).await
460 }
461
462 pub fn conn_for_each_blocking<F, T>(&self, func: F) -> Vec<Result<T, Error>>
464 where
465 F: Fn(&Connection) -> Result<T, rusqlite::Error> + Send + Sync + 'static,
466 T: Send + 'static,
467 {
468 let func = Arc::new(func);
469 self.state
470 .clients
471 .iter()
472 .map(|client| {
473 let func = func.clone();
474 client.conn_blocking(move |conn| func(conn))
475 })
476 .collect()
477 }
478}