// refinery_core/drivers/config.rs

1use crate::config::Config;
2use crate::traits::r#async::{AsyncQuery, AsyncTransaction};
3use crate::traits::sync::{Query, Transaction};
4use crate::Migration;
5#[cfg(any(
6    feature = "mysql",
7    feature = "postgres",
8    feature = "rusqlite",
9    feature = "tokio-postgres",
10    feature = "mysql_async",
11    feature = "tiberius-config"
12))]
13use crate::{
14    config::ConfigDbType,
15    error::WrapMigrationError,
16    traits::{GET_APPLIED_MIGRATIONS_QUERY, GET_LAST_APPLIED_MIGRATION_QUERY},
17    Error, Report, Target,
18};
19use async_trait::async_trait;
20use std::convert::Infallible;

// We implement all the dependent traits as no-ops and then override the methods that call them
// in the `Migrate` and `AsyncMigrate` impls for `Config`.
23impl Transaction for Config {
24    type Error = Infallible;
25
26    fn execute<'a, T: Iterator<Item = &'a str>>(
27        &mut self,
28        _queries: T,
29    ) -> Result<usize, Self::Error> {
30        Ok(0)
31    }
32}
33
34impl Query<Vec<Migration>> for Config {
35    fn query(&mut self, _query: &str) -> Result<Vec<Migration>, Self::Error> {
36        Ok(Vec::new())
37    }
38}
39
40#[async_trait]
41impl AsyncTransaction for Config {
42    type Error = Infallible;
43
44    async fn execute<'a, T: Iterator<Item = &'a str> + Send>(
45        &mut self,
46        _queries: T,
47    ) -> Result<usize, Self::Error> {
48        Ok(0)
49    }
50}
51
52#[async_trait]
53impl AsyncQuery<Vec<Migration>> for Config {
54    async fn query(
55        &mut self,
56        _query: &str,
57    ) -> Result<Vec<Migration>, <Self as AsyncTransaction>::Error> {
58        Ok(Vec::new())
59    }
60}
// Written as a macro so that we don't have to spell out the concrete connection type of each
// driver. `$op` is a closure that is called with whichever connection `$config.db_type()`
// selects, and the value of that call is the value of the whole `match`. The `?` operators in
// the expansion return early from the function that invokes the macro. Database types whose
// driver feature is not enabled (and MSSQL, which is async-only) panic.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "rusqlite"))]
#[allow(clippy::redundant_closure_call)]
macro_rules! with_connection {
    ($config:ident, $op: expr) => {
        // `$op` is a closure literal called in place, which clippy reports as redundant; the
        // closure form is what lets a single `$op` work with every driver's connection type.
        #[allow(clippy::redundant_closure_call)]
        match $config.db_type() {
            ConfigDbType::Mysql => {
                cfg_if::cfg_if! {
                    if #[cfg(feature = "mysql")] {
                        let url = crate::config::build_db_url("mysql", &$config);
                        let opts = mysql::Opts::from_url(&url).migration_err("could not parse url", None)?;
                        let conn = mysql::Conn::new(opts).migration_err("could not connect to database", None)?;
                        $op(conn)
                    } else {
                        panic!("tried to migrate from config for a mysql database, but feature mysql not enabled!");
                    }
                }
            }
            ConfigDbType::Sqlite => {
                cfg_if::cfg_if! {
                    if #[cfg(feature = "rusqlite")] {
                        // The path may already have been checked when the config was parsed. If it
                        // was not, it is passed through as-is (empty when unset); the database is
                        // opened without SQLITE_OPEN_CREATE, so a missing file is reported as a
                        // rusqlite open error rather than being created.
                        let path = $config.db_path().map(|p| p.to_path_buf()).unwrap_or_default();
                        let conn = rusqlite::Connection::open_with_flags(path, rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE).migration_err("could not open database", None)?;
                        $op(conn)
                    } else {
                        panic!("tried to migrate from config for a sqlite database, but feature rusqlite not enabled!");
                    }
                }
            }
            ConfigDbType::Postgres => {
                cfg_if::cfg_if! {
                    if #[cfg(feature = "postgres")] {
                        let path = crate::config::build_db_url("postgresql", &$config);

                        let conn;
                        cfg_if::cfg_if! {
                            if #[cfg(feature = "tls")] {
                                if $config.use_tls() {
                                    // Report a TLS backend initialization failure as a migration
                                    // error instead of panicking inside library code.
                                    let connector = native_tls::TlsConnector::new().migration_err("could not create TLS connector", None)?;
                                    let connector = postgres_native_tls::MakeTlsConnector::new(connector);
                                    conn = postgres::Client::connect(path.as_str(), connector).migration_err("could not connect to database", None)?;
                                } else {
                                    conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
                                }
                            } else if #[cfg(feature = "tokio-postgres-rustls")] {
                                if $config.use_tls() {
                                    panic!("tokio-postgres-rustls only supports the async tokio-postgres driver; enable the 'tls' feature to use TLS with the sync postgres driver");
                                }
                                conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
                            } else {
                                if $config.use_tls() {
                                    panic!("TLS was requested but neither 'tls' nor 'tokio-postgres-rustls' feature is enabled in refinery-core");
                                }
                                conn = postgres::Client::connect(path.as_str(), postgres::NoTls).migration_err("could not connect to database", None)?;
                            }
                        }

                        $op(conn)
                    } else {
                        panic!("tried to migrate from config for a postgresql database, but feature postgres not enabled!");
                    }
                }
            }
            ConfigDbType::Mssql => {
                panic!("tried to synchronously migrate from config for a mssql database, but tiberius is an async driver");
            }
        }
    }
}

// Async counterpart of `with_connection!`. `$op` must be a closure returning a future. It is
// called with the client (or, for MySQL, the pool) selected by `$config.db_type()`, and the
// returned future is awaited; the awaited value is the value of the whole `match`. The `?`
// operators in the expansion return early from the async function that invokes the macro.
// SQLite and database types whose driver feature is not enabled panic.
#[cfg(any(
    feature = "tokio-postgres",
    feature = "mysql_async",
    feature = "tiberius-config"
))]
macro_rules! with_connection_async {
    ($config: ident, $op: expr) => {
        // `$op` is a closure literal called in place, which clippy reports as redundant.
        #[allow(clippy::redundant_closure_call)]
        match $config.db_type() {
            ConfigDbType::Mysql => {
                cfg_if::cfg_if! {
                    if #[cfg(feature = "mysql_async")] {
                        let url = crate::config::build_db_url("mysql", $config);
                        let pool = mysql_async::Pool::from_url(&url).migration_err("could not connect to the database", None)?;
                        $op(pool).await
                    } else {
                        panic!("tried to migrate async from config for a mysql database, but feature mysql_async not enabled!");
                    }
                }
            }
            ConfigDbType::Sqlite => {
                panic!("tried to migrate async from config for a sqlite database, but this feature is not implemented yet");
            }
            ConfigDbType::Postgres => {
                cfg_if::cfg_if! {
                    if #[cfg(feature = "tokio-postgres")] {
                        let path = crate::config::build_db_url("postgresql", $config);
                        // In every branch below, `tokio_postgres::connect` yields a client plus a
                        // separate connection future; per the tokio-postgres docs that future must
                        // be driven for the client to work, so it is spawned onto the tokio
                        // runtime. Errors from it are only printed to stderr.
                        cfg_if::cfg_if! {
                            if #[cfg(feature = "tls")] {
                                if $config.use_tls() {
                                    // Report a TLS backend initialization failure as a migration
                                    // error instead of panicking inside library code.
                                    let connector = native_tls::TlsConnector::new().migration_err("could not create TLS connector", None)?;
                                    let connector = postgres_native_tls::MakeTlsConnector::new(connector);
                                    let (client, connection) = tokio_postgres::connect(path.as_str(), connector).await.migration_err("could not connect to database", None)?;
                                    tokio::spawn(async move {
                                        if let Err(e) = connection.await {
                                            eprintln!("connection error: {}", e);
                                        }
                                    });
                                    $op(client).await
                                } else {
                                    let (client, connection) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
                                    tokio::spawn(async move {
                                        if let Err(e) = connection.await {
                                            eprintln!("connection error: {}", e);
                                        }
                                    });
                                    $op(client).await
                                }
                            } else if #[cfg(feature = "tokio-postgres-rustls")] {
                                if $config.use_tls() {
                                    // Certificates that fail to load are logged and skipped; the
                                    // ones that parse become the trusted roots.
                                    let native_certs = rustls_native_certs::load_native_certs();
                                    for err in &native_certs.errors {
                                        log::warn!("Failed to load native TLS certificate: {err}");
                                    }
                                    let mut root_store = rustls::RootCertStore::empty();
                                    root_store.add_parsable_certificates(native_certs.certs);
                                    let tls_config = rustls::ClientConfig::builder()
                                        .with_root_certificates(root_store)
                                        .with_no_client_auth();
                                    let tls = tokio_postgres_rustls::MakeRustlsConnect::new(tls_config);
                                    let (client, connection) = tokio_postgres::connect(path.as_str(), tls).await.migration_err("could not connect to database", None)?;
                                    tokio::spawn(async move {
                                        if let Err(e) = connection.await {
                                            eprintln!("connection error: {}", e);
                                        }
                                    });
                                    $op(client).await
                                } else {
                                    let (client, connection) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
                                    tokio::spawn(async move {
                                        if let Err(e) = connection.await {
                                            eprintln!("connection error: {}", e);
                                        }
                                    });
                                    $op(client).await
                                }
                            } else {
                                if $config.use_tls() {
                                    panic!("TLS was requested but neither 'tls' nor 'tokio-postgres-rustls' feature is enabled in refinery-core");
                                }
                                let (client, connection) = tokio_postgres::connect(path.as_str(), tokio_postgres::NoTls).await.migration_err("could not connect to database", None)?;
                                tokio::spawn(async move {
                                    if let Err(e) = connection.await {
                                        eprintln!("connection error: {}", e);
                                    }
                                });
                                $op(client).await
                            }
                        }
                    } else {
                        panic!("tried to migrate async from config for a postgresql database, but tokio-postgres was not enabled!");
                    }
                }
            }
            ConfigDbType::Mssql => {
                cfg_if::cfg_if! {
                    if #[cfg(feature = "tiberius-config")] {
                        use tiberius::{Client, Config};
                        use tokio::net::TcpStream;
                        use tokio_util::compat::TokioAsyncWriteCompatExt;
                        use std::convert::TryInto;

                        // Convert refinery's config into tiberius's, open a TCP stream to the
                        // address it reports, and hand the stream to the tiberius client.
                        let config: Config = (&*$config).try_into()?;
                        let tcp = TcpStream::connect(config.get_addr())
                            .await
                            .migration_err("could not connect to database", None)?;
                        let client = Client::connect(config, tcp.compat_write())
                            .await
                            .migration_err("could not connect to database", None)?;

                        $op(client).await
                    } else {
                        panic!("tried to migrate async from config for a mssql database, but tiberius-config feature was not enabled!");
                    }
                }
            }
        }
    }
}

// A bare `Config` only describes how to connect, and its `Transaction`/`Query` impls are
// no-ops, so the `Migrate` methods that would call them are overridden here. Each one opens a
// real connection for the configured database via `with_connection!` and does its work on it.
#[cfg(any(feature = "mysql", feature = "postgres", feature = "rusqlite"))]
impl crate::Migrate for Config {
    /// Connects using this config and returns the last migration yielded by
    /// `GET_LAST_APPLIED_MIGRATION_QUERY`, or `None` when the query yields no rows.
    fn get_last_applied_migration(
        &mut self,
        migration_table_name: &str,
    ) -> Result<Option<Migration>, Error> {
        let sql = GET_LAST_APPLIED_MIGRATION_QUERY
            .replace("%MIGRATION_TABLE_NAME%", migration_table_name);
        with_connection!(self, |mut conn| {
            let rows: Vec<Migration> = Query::query(&mut conn, &sql)
                .migration_err("error getting last applied migration", None)?;
            Ok(rows.into_iter().last())
        })
    }

    /// Connects using this config and returns every migration yielded by
    /// `GET_APPLIED_MIGRATIONS_QUERY`.
    fn get_applied_migrations(
        &mut self,
        migration_table_name: &str,
    ) -> Result<Vec<Migration>, Error> {
        let sql =
            GET_APPLIED_MIGRATIONS_QUERY.replace("%MIGRATION_TABLE_NAME%", migration_table_name);
        with_connection!(self, |mut conn| {
            let applied: Vec<Migration> = Query::query(&mut conn, &sql)
                .migration_err("error getting applied migrations", None)?;
            Ok(applied)
        })
    }

    /// Connects using this config and runs `Migrate::migrate` on that connection with the
    /// same arguments.
    fn migrate(
        &mut self,
        migrations: &[Migration],
        abort_divergent: bool,
        abort_missing: bool,
        grouped: bool,
        target: Target,
        migration_table_name: &str,
    ) -> Result<Report, Error> {
        with_connection!(self, |mut conn| crate::Migrate::migrate(
            &mut conn,
            migrations,
            abort_divergent,
            abort_missing,
            grouped,
            target,
            migration_table_name,
        ))
    }
}

// Async counterpart of the `Migrate` impl for `Config`. Because `Config`'s own
// `AsyncTransaction`/`AsyncQuery` impls are no-ops, each method opens a real connection for the
// configured database via `with_connection_async!` and does its work on that connection.
#[cfg(any(
    feature = "mysql_async",
    feature = "tokio-postgres",
    feature = "tiberius-config"
))]
#[async_trait]
impl crate::AsyncMigrate for Config {
    /// Connects using this config and returns the last migration yielded by
    /// `GET_LAST_APPLIED_MIGRATION_QUERY`, or `None` when the query yields no rows.
    async fn get_last_applied_migration(
        &mut self,
        migration_table_name: &str,
    ) -> Result<Option<Migration>, Error> {
        with_connection_async!(self, move |mut conn| async move {
            let mut migrations: Vec<Migration> = AsyncQuery::query(
                &mut conn,
                &GET_LAST_APPLIED_MIGRATION_QUERY
                    .replace("%MIGRATION_TABLE_NAME%", migration_table_name),
            )
            .await
            .migration_err("error getting last applied migration", None)?;

            Ok(migrations.pop())
        })
    }

    /// Connects using this config and returns every migration yielded by
    /// `GET_APPLIED_MIGRATIONS_QUERY`.
    async fn get_applied_migrations(
        &mut self,
        migration_table_name: &str,
    ) -> Result<Vec<Migration>, Error> {
        with_connection_async!(self, move |mut conn| async move {
            let migrations: Vec<Migration> = AsyncQuery::query(
                &mut conn,
                &GET_APPLIED_MIGRATIONS_QUERY
                    .replace("%MIGRATION_TABLE_NAME%", migration_table_name),
            )
            .await
            // Describe the query that actually failed; matches the sync `Migrate` impl.
            .migration_err("error getting applied migrations", None)?;
            Ok(migrations)
        })
    }

    /// Connects using this config and runs `AsyncMigrate::migrate` on that connection with the
    /// same arguments.
    async fn migrate(
        &mut self,
        migrations: &[Migration],
        abort_divergent: bool,
        abort_missing: bool,
        grouped: bool,
        target: Target,
        migration_table_name: &str,
    ) -> Result<Report, Error> {
        with_connection_async!(self, move |mut conn| async move {
            crate::AsyncMigrate::migrate(
                &mut conn,
                migrations,
                abort_divergent,
                abort_missing,
                grouped,
                target,
                migration_table_name,
            )
            .await
        })
    }
}