1#![deny(missing_docs, missing_debug_implementations)]
5
6use std::{
7 path::{Path, PathBuf},
8 sync::Arc,
9};
10
11use async_trait::async_trait;
12use bb8::ManageConnection;
13use rusqlite::{Connection, OpenFlags};
14
15
16#[cfg(test)]
17mod tests;
18
19#[derive(Clone, Debug)]
22pub struct RusqliteConnectionManager(Arc<ConnectionOptions>);
23
/// Everything needed to open a connection; shared (via `Arc`) by all clones of a
/// `RusqliteConnectionManager`.
#[derive(Debug)]
struct ConnectionOptions {
    /// Which rusqlite open function to call, and with which extra arguments.
    mode: OpenMode,
    /// Database path handed to the rusqlite open function.
    path: PathBuf,
}
29
/// Selects which rusqlite constructor `connect` uses to open the database.
#[derive(Debug)]
enum OpenMode {
    /// Open with `rusqlite::Connection::open`.
    Plain,
    /// Open with `rusqlite::Connection::open_with_flags`.
    WithFlags {
        /// Open flags passed through to rusqlite.
        flags: rusqlite::OpenFlags,
    },
    /// Open with `rusqlite::Connection::open_with_flags_and_vfs`.
    WithFlagsAndVFS {
        /// Open flags passed through to rusqlite.
        flags: rusqlite::OpenFlags,
        /// Name of the SQLite VFS to open the database with.
        vfs: String,
    },
}
41
42#[derive(thiserror::Error, Debug)]
44pub enum Error {
45 #[error("rusqlite error")]
47 Rusqlite(#[from] rusqlite::Error),
48
49 #[error("tokio join error")]
51 TokioJoin(#[from] tokio::task::JoinError),
52}
53
54impl RusqliteConnectionManager {
55 pub fn new<P: AsRef<Path>>(path: P) -> Self {
57 Self(Arc::new(ConnectionOptions {
58 mode: OpenMode::Plain,
59 path: path.as_ref().into(),
60 }))
61 }
62
63 pub fn new_with_flags<P: AsRef<Path>>(path: P, flags: OpenFlags) -> Self {
65 Self(Arc::new(ConnectionOptions {
66 mode: OpenMode::WithFlags { flags },
67 path: path.as_ref().into(),
68 }))
69 }
70
71 pub fn new_with_flags_and_vfs<P: AsRef<Path>>(path: P, flags: OpenFlags, vfs: &str) -> Self {
73 Self(Arc::new(ConnectionOptions {
74 mode: OpenMode::WithFlagsAndVFS {
75 flags,
76 vfs: vfs.into(),
77 },
78 path: path.as_ref().into(),
79 }))
80 }
81}
82
83#[async_trait]
84impl ManageConnection for RusqliteConnectionManager {
85 type Connection = Connection;
86 type Error = Error;
87
88 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
89 let options = self.0.clone();
90
91 Ok(tokio::task::spawn_blocking(move || match &options.mode {
95 OpenMode::Plain => rusqlite::Connection::open(&options.path),
96 OpenMode::WithFlags { flags } => {
97 rusqlite::Connection::open_with_flags(&options.path, *flags)
98 }
99 OpenMode::WithFlagsAndVFS { flags, vfs } => {
100 rusqlite::Connection::open_with_flags_and_vfs(&options.path, *flags, vfs)
101 }
102 })
103 .await??)
104 }
105
106 async fn is_valid(
107 &self,
108 conn: &mut Self::Connection
109 ) -> Result<(), Self::Error> {
110 tokio::task::block_in_place(|| conn.execute("SELECT 1", []))?;
116 Ok(())
117 }
118
119 fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
120 false
125 }
126}