db_derive/
pool.rs

1use {
2    crate::{table::Schema, Error, Transaction},
3    std::path::Path,
4};
5
/// A pool of database connections for one of the backends enabled at compile time.
///
/// Each variant wraps an `r2d2::Pool` driven by the connection manager defined in
/// the matching submodule of this file.
#[derive(Debug, Clone)]
pub enum Pool {
    /// Pool of PostgreSQL clients; the manager is parameterised with `NoTls`.
    #[cfg(feature = "postgresql")]
    PostgreSQL(r2d2::Pool<self::postgres::PostgresConnectionManager<::postgres::NoTls>>),
    /// Pool of SQLite connections.
    #[cfg(feature = "sqlite")]
    SQLite(r2d2::Pool<self::sqlite::SqliteConnectionManager>),
}
13
14impl Pool {
15    #[cfg(feature = "postgresql")]
16    pub fn postgres(config: ::postgres::Config) -> Result<Self, Error> {
17        let conn = Pool::PostgreSQL(r2d2::Pool::new(
18            crate::pool::postgres::PostgresConnectionManager::new(config, ::postgres::NoTls),
19        )?);
20
21        Ok(conn)
22    }
23
24    #[cfg(feature = "sqlite")]
25    pub fn sqlite(path: impl AsRef<Path>) -> Result<Self, Error> {
26        let conn = Pool::SQLite(r2d2::Pool::new(
27            crate::pool::sqlite::SqliteConnectionManager::file(path),
28        )?);
29
30        Ok(conn)
31    }
32
33    pub fn as_kind(&self) -> PoolKind {
34        match self {
35            #[cfg(feature = "postgresql")]
36            Pool::PostgreSQL(_) => PoolKind::PostgreSQL,
37            #[cfg(feature = "sqlite")]
38            Pool::SQLite(_) => PoolKind::SQLite,
39        }
40    }
41
42    pub fn batch_execute(&self, exec: impl AsRef<str>) -> Result<(), Error> {
43        match self {
44            #[cfg(feature = "postgresql")]
45            Pool::PostgreSQL(pool) => {
46                let mut conn = pool.get()?;
47
48                conn.batch_execute(exec.as_ref())?;
49            }
50            #[cfg(feature = "sqlite")]
51            Pool::SQLite(pool) => {
52                let conn = pool.get()?;
53
54                conn.execute_batch(exec.as_ref())?;
55            }
56        }
57
58        Ok(())
59    }
60
61    pub fn transaction(
62        &self,
63        run: impl FnOnce(Transaction<'_>) -> Result<(), Error>,
64    ) -> Result<(), Error> {
65        match self {
66            #[cfg(feature = "postgresql")]
67            Pool::PostgreSQL(pool) => {
68                let mut conn = pool.get()?;
69
70                let trans = conn.transaction()?;
71
72                let inner = Transaction::PostgreSQL(trans);
73
74                run(inner)?;
75            }
76            #[cfg(feature = "sqlite")]
77            Pool::SQLite(pool) => {
78                let mut conn = pool.get()?;
79
80                let trans = conn.transaction()?;
81
82                let inner = Transaction::SQLite(trans);
83
84                run(inner)?;
85            }
86        }
87
88        Ok(())
89    }
90
91    pub fn schema<T: Schema>(&self) -> Result<(), Error> {
92        match self {
93            #[cfg(feature = "postgresql")]
94            Pool::PostgreSQL(pool) => {
95                let mut conn = pool.get()?;
96
97                conn.batch_execute(T::schema_postgres())?;
98            }
99            #[cfg(feature = "sqlite")]
100            Pool::SQLite(pool) => {
101                let conn = pool.get()?;
102
103                conn.execute_batch(T::schema_sqlite())?;
104            }
105        }
106
107        Ok(())
108    }
109}
110
/// Names the database backend of a [`Pool`] without holding the pool itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PoolKind {
    /// The PostgreSQL backend.
    #[cfg(feature = "postgresql")]
    PostgreSQL,
    /// The SQLite backend.
    #[cfg(feature = "sqlite")]
    SQLite,
}
118
119impl From<Pool> for PoolKind {
120    fn from(pool: Pool) -> PoolKind {
121        pool.as_kind()
122    }
123}
124
125impl<'a> From<&'a Pool> for PoolKind {
126    fn from(pool: &'a Pool) -> PoolKind {
127        pool.as_kind()
128    }
129}
130
/// `r2d2` connection management for `postgres::Client`.
#[cfg(feature = "postgresql")]
pub mod postgres {
    use {
        postgres::{
            tls::{MakeTlsConnect, TlsConnect},
            Client, Config, Error, Socket,
        },
        r2d2::ManageConnection,
    };

    /// Connection manager that opens `postgres` clients from a stored `Config`
    /// and TLS connector.
    #[derive(Debug)]
    pub struct PostgresConnectionManager<T> {
        config: Config,
        tls_connector: T,
    }

    impl<T> PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + 'static + Sync + Send,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        /// Creates a manager that will connect with `config` and `tls_connector`.
        pub fn new(config: Config, tls_connector: T) -> Self {
            Self {
                config,
                tls_connector,
            }
        }
    }

    impl<T> ManageConnection for PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + 'static + Sync + Send,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        type Connection = Client;
        type Error = Error;

        /// Opens a new client from the stored configuration, using a clone of
        /// the TLS connector.
        fn connect(&self) -> Result<Self::Connection, Self::Error> {
            let tls = self.tls_connector.clone();

            self.config.connect(tls)
        }

        /// Checks the client by sending an empty simple query; the query's
        /// result is discarded.
        fn is_valid(&self, client: &mut Self::Connection) -> Result<(), Self::Error> {
            client.simple_query("")?;

            Ok(())
        }

        /// Reports the client as broken when `Client::is_closed` returns `true`.
        fn has_broken(&self, client: &mut Self::Connection) -> bool {
            client.is_closed()
        }
    }
}
186
/// `r2d2` connection management for `rusqlite::Connection`.
#[cfg(feature = "sqlite")]
pub mod sqlite {
    use {
        rusqlite::{Connection, Error, OpenFlags},
        std::{
            fmt,
            path::{Path, PathBuf},
        },
    };

    /// Connection manager that opens SQLite connections to a fixed file path.
    pub struct SqliteConnectionManager {
        path: PathBuf,
    }

    impl fmt::Debug for SqliteConnectionManager {
        /// Formats the manager as a struct with its single `path` field.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("SqliteConnectionManager")
                .field("path", &self.path)
                .finish()
        }
    }

    impl SqliteConnectionManager {
        /// Creates a manager for the database file at `path`; the path is
        /// copied into an owned `PathBuf`.
        pub fn file<P: AsRef<Path>>(path: P) -> Self {
            let path = path.as_ref().to_path_buf();

            Self { path }
        }
    }

    impl r2d2::ManageConnection for SqliteConnectionManager {
        type Connection = Connection;
        type Error = Error;

        /// Opens a connection to the stored path with `OpenFlags::default()`.
        fn connect(&self) -> Result<Connection, Error> {
            Connection::open_with_flags(&self.path, OpenFlags::default())
        }

        /// Checks the connection by executing an empty batch on it.
        fn is_valid(&self, conn: &mut Connection) -> Result<(), Error> {
            conn.execute_batch("")
        }

        /// Always reports the connection as not broken.
        fn has_broken(&self, _: &mut Connection) -> bool {
            false
        }
    }
}
233}