// quaint_forked/connector/mysql.rs

1mod conversion;
2mod error;
3
4use crate::{
5    ast::{Query, Value},
6    connector::{metrics, queryable::*, ResultSet},
7    error::{Error, ErrorKind},
8    visitor::{self, Visitor},
9};
10use async_trait::async_trait;
11use lru_cache::LruCache;
12use mysql_async::{
13    self as my,
14    prelude::{Query as _, Queryable as _},
15};
16use percent_encoding::percent_decode;
17use std::{
18    borrow::Cow,
19    future::Future,
20    path::{Path, PathBuf},
21    sync::atomic::{AtomicBool, Ordering},
22    time::Duration,
23};
24use tokio::sync::Mutex;
25use url::Url;
26
27/// The underlying MySQL driver. Only available with the `expose-drivers`
28/// Cargo feature.
29#[cfg(feature = "expose-drivers")]
30pub use mysql_async;
31
32use super::IsolationLevel;
33
34/// A connector interface for the MySQL database.
35#[derive(Debug)]
36pub struct Mysql {
37    pub(crate) conn: Mutex<my::Conn>,
38    pub(crate) url: MysqlUrl,
39    socket_timeout: Option<Duration>,
40    is_healthy: AtomicBool,
41    statement_cache: Mutex<LruCache<String, my::Statement>>,
42}
43
44/// Wraps a connection url and exposes the parsing logic used by quaint, including default values.
45#[derive(Debug, Clone)]
46pub struct MysqlUrl {
47    url: Url,
48    query_params: MysqlUrlQueryParams,
49}
50
51impl MysqlUrl {
52    /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection
53    /// parameters.
54    pub fn new(url: Url) -> Result<Self, Error> {
55        let query_params = Self::parse_query_params(&url)?;
56
57        Ok(Self { url, query_params })
58    }
59
60    /// The bare `Url` to the database.
61    pub fn url(&self) -> &Url {
62        &self.url
63    }
64
65    /// The percent-decoded database username.
66    pub fn username(&self) -> Cow<str> {
67        match percent_decode(self.url.username().as_bytes()).decode_utf8() {
68            Ok(username) => username,
69            Err(_) => {
70                tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version.");
71
72                self.url.username().into()
73            }
74        }
75    }
76
77    /// The percent-decoded database password.
78    pub fn password(&self) -> Option<Cow<str>> {
79        match self
80            .url
81            .password()
82            .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok())
83        {
84            Some(password) => Some(password),
85            None => self.url.password().map(|s| s.into()),
86        }
87    }
88
89    /// Name of the database connected. Defaults to `mysql`.
90    pub fn dbname(&self) -> &str {
91        match self.url.path_segments() {
92            Some(mut segments) => segments.next().unwrap_or("mysql"),
93            None => "mysql",
94        }
95    }
96
97    /// The database host. If `socket` and `host` are not set, defaults to `localhost`.
98    pub fn host(&self) -> &str {
99        self.url.host_str().unwrap_or("localhost")
100    }
101
102    /// If set, connected to the database through a Unix socket.
103    pub fn socket(&self) -> &Option<String> {
104        &self.query_params.socket
105    }
106
107    /// The database port, defaults to `3306`.
108    pub fn port(&self) -> u16 {
109        self.url.port().unwrap_or(3306)
110    }
111
112    /// The connection timeout.
113    pub fn connect_timeout(&self) -> Option<Duration> {
114        self.query_params.connect_timeout
115    }
116
117    /// The pool check_out timeout
118    pub fn pool_timeout(&self) -> Option<Duration> {
119        self.query_params.pool_timeout
120    }
121
122    /// The socket timeout
123    pub fn socket_timeout(&self) -> Option<Duration> {
124        self.query_params.socket_timeout
125    }
126
127    /// Prefer socket connection
128    pub fn prefer_socket(&self) -> Option<bool> {
129        self.query_params.prefer_socket
130    }
131
132    /// The maximum connection lifetime
133    pub fn max_connection_lifetime(&self) -> Option<Duration> {
134        self.query_params.max_connection_lifetime
135    }
136
137    /// The maximum idle connection lifetime
138    pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
139        self.query_params.max_idle_connection_lifetime
140    }
141
142    fn statement_cache_size(&self) -> usize {
143        self.query_params.statement_cache_size
144    }
145
146    pub(crate) fn cache(&self) -> LruCache<String, my::Statement> {
147        LruCache::new(self.query_params.statement_cache_size)
148    }
149
150    fn parse_query_params(url: &Url) -> Result<MysqlUrlQueryParams, Error> {
151        let mut ssl_opts = my::SslOpts::default();
152        ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true);
153
154        let mut connection_limit = None;
155        let mut use_ssl = false;
156        let mut socket = None;
157        let mut socket_timeout = None;
158        let mut connect_timeout = Some(Duration::from_secs(5));
159        let mut pool_timeout = Some(Duration::from_secs(10));
160        let mut max_connection_lifetime = None;
161        let mut max_idle_connection_lifetime = Some(Duration::from_secs(300));
162        let mut prefer_socket = None;
163        let mut statement_cache_size = 100;
164        let mut identity: Option<(Option<PathBuf>, Option<String>)> = None;
165
166        for (k, v) in url.query_pairs() {
167            match k.as_ref() {
168                "connection_limit" => {
169                    let as_int: usize = v
170                        .parse()
171                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
172
173                    connection_limit = Some(as_int);
174                }
175                "statement_cache_size" => {
176                    statement_cache_size = v
177                        .parse()
178                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
179                }
180                "sslcert" => {
181                    use_ssl = true;
182                    ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf()));
183                }
184                "sslidentity" => {
185                    use_ssl = true;
186
187                    identity = match identity {
188                        Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)),
189                        None => Some((Some(Path::new(&*v).to_path_buf()), None)),
190                    };
191                }
192                "sslpassword" => {
193                    use_ssl = true;
194
195                    identity = match identity {
196                        Some((path, _)) => Some((path, Some(v.to_string()))),
197                        None => Some((None, Some(v.to_string()))),
198                    };
199                }
200                "socket" => {
201                    socket = Some(v.replace(['(', ')'], ""));
202                }
203                "socket_timeout" => {
204                    let as_int = v
205                        .parse()
206                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
207                    socket_timeout = Some(Duration::from_secs(as_int));
208                }
209                "prefer_socket" => {
210                    let as_bool = v
211                        .parse::<bool>()
212                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
213                    prefer_socket = Some(as_bool)
214                }
215                "connect_timeout" => {
216                    let as_int = v
217                        .parse::<u64>()
218                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
219
220                    connect_timeout = match as_int {
221                        0 => None,
222                        _ => Some(Duration::from_secs(as_int)),
223                    };
224                }
225                "pool_timeout" => {
226                    let as_int = v
227                        .parse::<u64>()
228                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
229
230                    pool_timeout = match as_int {
231                        0 => None,
232                        _ => Some(Duration::from_secs(as_int)),
233                    };
234                }
235                "sslaccept" => {
236                    use_ssl = true;
237                    match v.as_ref() {
238                        "strict" => {
239                            ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false);
240                        }
241                        "accept_invalid_certs" => {}
242                        _ => {
243                            tracing::debug!(
244                                message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`",
245                                mode = &*v
246                            );
247                        }
248                    };
249                }
250                "max_connection_lifetime" => {
251                    let as_int = v
252                        .parse()
253                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
254
255                    if as_int == 0 {
256                        max_connection_lifetime = None;
257                    } else {
258                        max_connection_lifetime = Some(Duration::from_secs(as_int));
259                    }
260                }
261                "max_idle_connection_lifetime" => {
262                    let as_int = v
263                        .parse()
264                        .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
265
266                    if as_int == 0 {
267                        max_idle_connection_lifetime = None;
268                    } else {
269                        max_idle_connection_lifetime = Some(Duration::from_secs(as_int));
270                    }
271                }
272                _ => {
273                    tracing::trace!(message = "Discarding connection string param", param = &*k);
274                }
275            };
276        }
277
278        ssl_opts = match identity {
279            Some((Some(path), Some(pw))) => {
280                let identity = mysql_async::ClientIdentity::new(path).with_password(pw);
281                ssl_opts.with_client_identity(Some(identity))
282            }
283            Some((Some(path), None)) => {
284                let identity = mysql_async::ClientIdentity::new(path);
285                ssl_opts.with_client_identity(Some(identity))
286            }
287            _ => ssl_opts,
288        };
289
290        Ok(MysqlUrlQueryParams {
291            ssl_opts,
292            connection_limit,
293            use_ssl,
294            socket,
295            socket_timeout,
296            connect_timeout,
297            pool_timeout,
298            max_connection_lifetime,
299            max_idle_connection_lifetime,
300            prefer_socket,
301            statement_cache_size,
302        })
303    }
304
305    #[cfg(feature = "pooled")]
306    pub(crate) fn connection_limit(&self) -> Option<usize> {
307        self.query_params.connection_limit
308    }
309
310    pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder {
311        let mut config = my::OptsBuilder::default()
312            .stmt_cache_size(Some(0))
313            .user(Some(self.username()))
314            .pass(self.password())
315            .db_name(Some(self.dbname()));
316
317        match self.socket() {
318            Some(ref socket) => {
319                config = config.socket(Some(socket));
320            }
321            None => {
322                config = config.ip_or_hostname(self.host()).tcp_port(self.port());
323            }
324        }
325
326        config = config.conn_ttl(Some(Duration::from_secs(5)));
327
328        if self.query_params.use_ssl {
329            config = config.ssl_opts(Some(self.query_params.ssl_opts.clone()));
330        }
331
332        if self.query_params.prefer_socket.is_some() {
333            config = config.prefer_socket(self.query_params.prefer_socket);
334        }
335
336        config
337    }
338}
339
340#[derive(Debug, Clone)]
341pub(crate) struct MysqlUrlQueryParams {
342    ssl_opts: my::SslOpts,
343    connection_limit: Option<usize>,
344    use_ssl: bool,
345    socket: Option<String>,
346    socket_timeout: Option<Duration>,
347    connect_timeout: Option<Duration>,
348    pool_timeout: Option<Duration>,
349    max_connection_lifetime: Option<Duration>,
350    max_idle_connection_lifetime: Option<Duration>,
351    prefer_socket: Option<bool>,
352    statement_cache_size: usize,
353}
354
355impl Mysql {
356    /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate.
357    pub async fn new(url: MysqlUrl) -> crate::Result<Self> {
358        let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?;
359
360        Ok(Self {
361            socket_timeout: url.query_params.socket_timeout,
362            conn: Mutex::new(conn),
363            statement_cache: Mutex::new(url.cache()),
364            url,
365            is_healthy: AtomicBool::new(true),
366        })
367    }
368
369    /// The underlying mysql_async::Conn. Only available with the
370    /// `expose-drivers` Cargo feature. This is a lower level API when you need
371    /// to get into database specific features.
372    #[cfg(feature = "expose-drivers")]
373    pub fn conn(&self) -> &Mutex<mysql_async::Conn> {
374        &self.conn
375    }
376
377    async fn perform_io<F, U, T>(&self, op: U) -> crate::Result<T>
378    where
379        F: Future<Output = crate::Result<T>>,
380        U: FnOnce() -> F,
381    {
382        match super::timeout::socket(self.socket_timeout, op()).await {
383            Err(e) if e.is_closed() => {
384                self.is_healthy.store(false, Ordering::SeqCst);
385                Err(e)
386            }
387            res => Ok(res?),
388        }
389    }
390
391    async fn prepared<F, U, T>(&self, sql: &str, op: U) -> crate::Result<T>
392    where
393        F: Future<Output = crate::Result<T>>,
394        U: Fn(my::Statement) -> F,
395    {
396        if self.url.statement_cache_size() == 0 {
397            self.perform_io(|| async move {
398                let stmt = {
399                    let mut conn = self.conn.lock().await;
400                    conn.prep(sql).await?
401                };
402
403                let res = op(stmt.clone()).await;
404
405                {
406                    let mut conn = self.conn.lock().await;
407                    conn.close(stmt).await?;
408                }
409
410                res
411            })
412            .await
413        } else {
414            self.perform_io(|| async move {
415                let stmt = self.fetch_cached(sql).await?;
416                op(stmt).await
417            })
418            .await
419        }
420    }
421
422    async fn fetch_cached(&self, sql: &str) -> crate::Result<my::Statement> {
423        let mut cache = self.statement_cache.lock().await;
424        let capacity = cache.capacity();
425        let stored = cache.len();
426
427        match cache.get_mut(sql) {
428            Some(stmt) => {
429                tracing::trace!(
430                    message = "CACHE HIT!",
431                    query = sql,
432                    capacity = capacity,
433                    stored = stored,
434                );
435
436                Ok(stmt.clone()) // arc'd
437            }
438            None => {
439                tracing::trace!(
440                    message = "CACHE MISS!",
441                    query = sql,
442                    capacity = capacity,
443                    stored = stored,
444                );
445
446                let mut conn = self.conn.lock().await;
447                if cache.capacity() == cache.len() {
448                    if let Some((_, stmt)) = cache.remove_lru() {
449                        conn.close(stmt).await?;
450                    }
451                }
452
453                let stmt = conn.prep(sql).await?;
454                cache.insert(sql.to_string(), stmt.clone());
455
456                Ok(stmt)
457            }
458        }
459    }
460}
461
462impl TransactionCapable for Mysql {}
463
464#[async_trait]
465impl Queryable for Mysql {
466    async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
467        let (sql, params) = visitor::Mysql::build(q)?;
468        self.query_raw(&sql, &params).await
469    }
470
471    async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
472        metrics::query("mysql.query_raw", sql, params, move || async move {
473            self.prepared(sql, |stmt| async move {
474                let mut conn = self.conn.lock().await;
475                let rows: Vec<my::Row> = conn.exec(&stmt, conversion::conv_params(params)?).await?;
476                let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect();
477
478                let last_id = conn.last_insert_id();
479                let mut result_set = ResultSet::new(columns, Vec::new());
480
481                for mut row in rows {
482                    result_set.rows.push(row.take_result_row()?);
483                }
484
485                if let Some(id) = last_id {
486                    result_set.set_last_insert_id(id);
487                };
488
489                Ok(result_set)
490            })
491            .await
492        })
493        .await
494    }
495
496    async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
497        self.query_raw(sql, params).await
498    }
499
500    async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
501        let (sql, params) = visitor::Mysql::build(q)?;
502        self.execute_raw(&sql, &params).await
503    }
504
505    async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
506        metrics::query("mysql.execute_raw", sql, params, move || async move {
507            self.prepared(sql, |stmt| async move {
508                let mut conn = self.conn.lock().await;
509                conn.exec_drop(stmt, conversion::conv_params(params)?).await?;
510
511                Ok(conn.affected_rows())
512            })
513            .await
514        })
515        .await
516    }
517
518    async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
519        self.execute_raw(sql, params).await
520    }
521
522    async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
523        metrics::query("mysql.raw_cmd", cmd, &[], move || async move {
524            self.perform_io(|| async move {
525                let mut conn = self.conn.lock().await;
526                let mut result = cmd.run(&mut *conn).await?;
527
528                loop {
529                    result.map(drop).await?;
530
531                    if result.is_empty() {
532                        result.map(drop).await?;
533                        break;
534                    }
535                }
536
537                Ok(())
538            })
539            .await
540        })
541        .await
542    }
543
544    async fn version(&self) -> crate::Result<Option<String>> {
545        let query = r#"SELECT @@GLOBAL.version version"#;
546        let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?;
547
548        let version_string = rows
549            .get(0)
550            .and_then(|row| row.get("version").and_then(|version| version.to_string()));
551
552        Ok(version_string)
553    }
554
555    fn is_healthy(&self) -> bool {
556        self.is_healthy.load(Ordering::SeqCst)
557    }
558
559    async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
560        if matches!(isolation_level, IsolationLevel::Snapshot) {
561            return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build());
562        }
563
564        self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}"))
565            .await?;
566
567        Ok(())
568    }
569
570    fn requires_isolation_first(&self) -> bool {
571        true
572    }
573}
574
#[cfg(test)]
mod tests {
    use super::MysqlUrl;
    use crate::tests::test_api::mysql::CONN_STR;
    use crate::{error::*, single::Quaint};
    use url::Url;

    /// Parses a connection string the test itself guarantees to be valid.
    fn mysql_url(s: &str) -> MysqlUrl {
        MysqlUrl::new(Url::parse(s).unwrap()).unwrap()
    }

    #[test]
    fn should_parse_socket_url() {
        let url = mysql_url("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)");
        assert_eq!("dbname", url.dbname());
        assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket());
    }

    #[test]
    fn should_parse_prefer_socket() {
        let url = mysql_url("mysql://root:root@localhost:3307/testdb?prefer_socket=false");
        assert_eq!(Some(false), url.prefer_socket());
    }

    #[test]
    fn should_parse_sslaccept() {
        let url = mysql_url("mysql://root:root@localhost:3307/testdb?sslaccept=strict");
        assert!(url.query_params.use_ssl);
        assert!(!url.query_params.ssl_opts.skip_domain_validation());
        assert!(!url.query_params.ssl_opts.accept_invalid_certs());
    }

    #[test]
    fn should_allow_changing_of_cache_size() {
        // Two slashes: with `mysql:///` the credentials and host would end up in the path.
        let url = mysql_url("mysql://root:root@localhost:3307/foo?statement_cache_size=420");
        assert_eq!(420, url.cache().capacity());
    }

    #[test]
    fn should_have_default_cache_size() {
        let url = mysql_url("mysql://root:root@localhost:3307/foo");
        assert_eq!(100, url.cache().capacity());
    }

    #[test]
    fn should_treat_zero_timeouts_as_disabled() {
        let url = mysql_url(
            "mysql://root:root@localhost:3307/foo?connect_timeout=0&pool_timeout=0&max_idle_connection_lifetime=0",
        );
        assert_eq!(None, url.connect_timeout());
        assert_eq!(None, url.pool_timeout());
        assert_eq!(None, url.max_idle_connection_lifetime());
    }

    #[test]
    fn should_reject_non_numeric_timeouts() {
        let url = Url::parse("mysql://root:root@localhost:3307/foo?connect_timeout=soon").unwrap();
        assert!(MysqlUrl::new(url).is_err());
    }

    #[tokio::test]
    async fn should_map_nonexisting_database_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_username("root").unwrap();
        url.set_path("/this_does_not_exist");

        let err = Quaint::new(url.as_str()).await.unwrap_err();

        match err.kind() {
            ErrorKind::DatabaseDoesNotExist { db_name } => {
                assert_eq!(Some("1049"), err.original_code());
                assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message());
                assert_eq!(&Name::available("this_does_not_exist"), db_name)
            }
            e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e),
        }
    }

    #[tokio::test]
    async fn should_map_wrong_credentials_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_username("WRONG").unwrap();

        let res = Quaint::new(url.as_str()).await;
        assert!(res.is_err());

        let err = res.unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG")));
    }
}