1#![deny(missing_docs, missing_debug_implementations)]
5
6use std::{
7 path::{Path, PathBuf},
8 sync::Arc,
9};
10
11use async_trait::async_trait;
12use bb8::ManageConnection;
13use rusqlite::{Connection, OpenFlags, NO_PARAMS};
14
15#[cfg(test)]
16mod tests;
17
18#[derive(Clone, Debug)]
21pub struct RusqliteConnectionManager(Arc<ConnectionOptions>);
22
/// Everything needed to open one more connection; shared by all clones of the manager.
#[derive(Debug)]
struct ConnectionOptions {
    /// Which `Connection` constructor to call, plus its extra arguments.
    mode: OpenMode,
    /// Database path passed to the chosen constructor.
    path: PathBuf,
}
28
29#[derive(Debug)]
30enum OpenMode {
31 Plain,
32 WithFlags {
33 flags: rusqlite::OpenFlags,
34 },
35 WithFlagsAndVFS {
36 flags: rusqlite::OpenFlags,
37 vfs: String,
38 },
39}
40
41#[derive(thiserror::Error, Debug)]
43pub enum Error {
44 #[error("rusqlite error")]
46 Rusqlite(#[from] rusqlite::Error),
47
48 #[error("tokio join error")]
50 TokioJoin(#[from] tokio::task::JoinError),
51}
52
53impl RusqliteConnectionManager {
54 pub fn new<P>(path: P) -> Self
56 where
57 P: AsRef<Path>,
58 {
59 Self(Arc::new(ConnectionOptions {
60 mode: OpenMode::Plain,
61 path: path.as_ref().into(),
62 }))
63 }
64
65 pub fn new_with_flags<P>(path: P, flags: OpenFlags) -> Self
67 where
68 P: AsRef<Path>,
69 {
70 Self(Arc::new(ConnectionOptions {
71 mode: OpenMode::WithFlags { flags },
72 path: path.as_ref().into(),
73 }))
74 }
75
76 pub fn new_with_flags_and_vfs<P>(path: P, flags: OpenFlags, vfs: &str) -> Self
78 where
79 P: AsRef<Path>,
80 {
81 Self(Arc::new(ConnectionOptions {
82 mode: OpenMode::WithFlagsAndVFS {
83 flags,
84 vfs: vfs.into(),
85 },
86 path: path.as_ref().into(),
87 }))
88 }
89}
90
91#[async_trait]
92impl ManageConnection for RusqliteConnectionManager {
93 type Connection = Connection;
94 type Error = Error;
95
96 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
97 let options = self.0.clone();
98
99 Ok(tokio::task::spawn_blocking(move || match &options.mode {
103 OpenMode::Plain => rusqlite::Connection::open(&options.path),
104 OpenMode::WithFlags { flags } => {
105 rusqlite::Connection::open_with_flags(&options.path, *flags)
106 }
107 OpenMode::WithFlagsAndVFS { flags, vfs } => {
108 rusqlite::Connection::open_with_flags_and_vfs(&options.path, *flags, &vfs)
109 }
110 })
111 .await??)
112 }
113
114 async fn is_valid(
115 &self,
116 conn: &mut bb8::PooledConnection<'_, Self>,
117 ) -> Result<(), Self::Error> {
118 tokio::task::block_in_place(|| conn.execute("SELECT 1", NO_PARAMS))?;
124 Ok(())
125 }
126
127 fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
128 false
133 }
134}