1use {
2 crate::{table::Schema, Error, Transaction},
3 std::path::Path,
4};
5
/// A connection pool for one of the supported database backends.
///
/// Which variants exist depends on the enabled Cargo features
/// (`postgresql` and/or `sqlite`).
#[derive(Debug, Clone)]
pub enum Pool {
    /// Pool of PostgreSQL clients connected without TLS (`NoTls`).
    #[cfg(feature = "postgresql")]
    PostgreSQL(r2d2::Pool<self::postgres::PostgresConnectionManager<::postgres::NoTls>>),
    /// Pool of SQLite connections opened on a database file.
    #[cfg(feature = "sqlite")]
    SQLite(r2d2::Pool<self::sqlite::SqliteConnectionManager>),
}
13
14impl Pool {
15 #[cfg(feature = "postgresql")]
16 pub fn postgres(config: ::postgres::Config) -> Result<Self, Error> {
17 let conn = Pool::PostgreSQL(r2d2::Pool::new(
18 crate::pool::postgres::PostgresConnectionManager::new(config, ::postgres::NoTls),
19 )?);
20
21 Ok(conn)
22 }
23
24 #[cfg(feature = "sqlite")]
25 pub fn sqlite(path: impl AsRef<Path>) -> Result<Self, Error> {
26 let conn = Pool::SQLite(r2d2::Pool::new(
27 crate::pool::sqlite::SqliteConnectionManager::file(path),
28 )?);
29
30 Ok(conn)
31 }
32
33 pub fn as_kind(&self) -> PoolKind {
34 match self {
35 #[cfg(feature = "postgresql")]
36 Pool::PostgreSQL(_) => PoolKind::PostgreSQL,
37 #[cfg(feature = "sqlite")]
38 Pool::SQLite(_) => PoolKind::SQLite,
39 }
40 }
41
42 pub fn batch_execute(&self, exec: impl AsRef<str>) -> Result<(), Error> {
43 match self {
44 #[cfg(feature = "postgresql")]
45 Pool::PostgreSQL(pool) => {
46 let mut conn = pool.get()?;
47
48 conn.batch_execute(exec.as_ref())?;
49 }
50 #[cfg(feature = "sqlite")]
51 Pool::SQLite(pool) => {
52 let conn = pool.get()?;
53
54 conn.execute_batch(exec.as_ref())?;
55 }
56 }
57
58 Ok(())
59 }
60
61 pub fn transaction(
62 &self,
63 run: impl FnOnce(Transaction<'_>) -> Result<(), Error>,
64 ) -> Result<(), Error> {
65 match self {
66 #[cfg(feature = "postgresql")]
67 Pool::PostgreSQL(pool) => {
68 let mut conn = pool.get()?;
69
70 let trans = conn.transaction()?;
71
72 let inner = Transaction::PostgreSQL(trans);
73
74 run(inner)?;
75 }
76 #[cfg(feature = "sqlite")]
77 Pool::SQLite(pool) => {
78 let mut conn = pool.get()?;
79
80 let trans = conn.transaction()?;
81
82 let inner = Transaction::SQLite(trans);
83
84 run(inner)?;
85 }
86 }
87
88 Ok(())
89 }
90
91 pub fn schema<T: Schema>(&self) -> Result<(), Error> {
92 match self {
93 #[cfg(feature = "postgresql")]
94 Pool::PostgreSQL(pool) => {
95 let mut conn = pool.get()?;
96
97 conn.batch_execute(T::schema_postgres())?;
98 }
99 #[cfg(feature = "sqlite")]
100 Pool::SQLite(pool) => {
101 let conn = pool.get()?;
102
103 conn.execute_batch(T::schema_sqlite())?;
104 }
105 }
106
107 Ok(())
108 }
109}
110
/// The database backend a [`Pool`] talks to, as a plain copyable tag
/// that holds no connections.
///
/// Variant order is significant: it drives the derived `PartialOrd`/`Ord`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum PoolKind {
    /// PostgreSQL backend.
    #[cfg(feature = "postgresql")]
    PostgreSQL,
    /// SQLite backend.
    #[cfg(feature = "sqlite")]
    SQLite,
}
118
119impl From<Pool> for PoolKind {
120 fn from(pool: Pool) -> PoolKind {
121 pool.as_kind()
122 }
123}
124
125impl<'a> From<&'a Pool> for PoolKind {
126 fn from(pool: &'a Pool) -> PoolKind {
127 pool.as_kind()
128 }
129}
130
#[cfg(feature = "postgresql")]
pub mod postgres {
    //! An r2d2 connection manager for synchronous `postgres` clients.

    use {
        postgres::{
            tls::{MakeTlsConnect, TlsConnect},
            Client, Config, Error, Socket,
        },
        r2d2::ManageConnection,
    };

    /// Opens `postgres` [`Client`]s for an r2d2 pool, using a stored [`Config`]
    /// and TLS connector.
    #[derive(Debug)]
    pub struct PostgresConnectionManager<T> {
        config: Config,
        tls_connector: T,
    }

    impl<T> PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + 'static + Sync + Send,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        /// Creates a manager that connects with `config`, cloning `tls_connector`
        /// for every new connection.
        pub fn new(config: Config, tls_connector: T) -> PostgresConnectionManager<T> {
            Self { config, tls_connector }
        }
    }

    impl<T> ManageConnection for PostgresConnectionManager<T>
    where
        T: MakeTlsConnect<Socket> + Clone + 'static + Sync + Send,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        type Connection = Client;
        type Error = Error;

        /// Opens a new client, using a fresh clone of the TLS connector.
        fn connect(&self) -> Result<Client, Error> {
            let connector = self.tls_connector.clone();
            self.config.connect(connector)
        }

        /// Sends an empty simple query. Any error marks the connection invalid.
        fn is_valid(&self, client: &mut Client) -> Result<(), Error> {
            client.simple_query("")?;
            Ok(())
        }

        /// A client whose underlying connection has closed is reported as broken.
        fn has_broken(&self, client: &mut Client) -> bool {
            client.is_closed()
        }
    }
}
186
#[cfg(feature = "sqlite")]
pub mod sqlite {
    //! An r2d2 connection manager for `rusqlite` connections to a database file.

    use {
        rusqlite::{Connection, Error, OpenFlags},
        std::path::{Path, PathBuf},
    };

    /// Opens `rusqlite` [`Connection`]s to a single database file for an r2d2 pool.
    // The derived `Debug` prints `SqliteConnectionManager { path: .. }`, exactly what
    // the previous hand-written impl produced.
    #[derive(Debug)]
    pub struct SqliteConnectionManager {
        path: PathBuf,
    }

    impl SqliteConnectionManager {
        /// Creates a manager whose connections open the database at `path`.
        pub fn file<P: AsRef<Path>>(path: P) -> Self {
            Self {
                path: path.as_ref().to_path_buf(),
            }
        }
    }

    impl r2d2::ManageConnection for SqliteConnectionManager {
        type Connection = Connection;
        type Error = Error;

        /// Opens the database file with rusqlite's default open flags.
        fn connect(&self) -> Result<Connection, Error> {
            // The driver already returns `rusqlite::Error`, so no conversion is needed.
            Connection::open_with_flags(&self.path, OpenFlags::default())
        }

        /// Runs an empty batch. Any error marks the connection invalid.
        fn is_valid(&self, conn: &mut Connection) -> Result<(), Error> {
            conn.execute_batch("")
        }

        /// Always `false`: this manager does no broken-connection check of its own,
        /// so validity is left to `is_valid`.
        fn has_broken(&self, _: &mut Connection) -> bool {
            false
        }
    }
}