1mod conversion;
2mod error;
3
4use crate::{
5 ast::{Query, Value},
6 connector::{metrics, queryable::*, ResultSet},
7 error::{Error, ErrorKind},
8 visitor::{self, Visitor},
9};
10use async_trait::async_trait;
11use lru_cache::LruCache;
12use mysql_async::{
13 self as my,
14 prelude::{Query as _, Queryable as _},
15};
16use percent_encoding::percent_decode;
17use std::{
18 borrow::Cow,
19 future::Future,
20 path::{Path, PathBuf},
21 sync::atomic::{AtomicBool, Ordering},
22 time::Duration,
23};
24use tokio::sync::Mutex;
25use url::Url;
26
27#[cfg(feature = "expose-drivers")]
30pub use mysql_async;
31
32use super::IsolationLevel;
33
/// A connector for a MySQL database, wrapping a single `mysql_async` connection.
#[derive(Debug)]
pub struct Mysql {
    /// The underlying driver connection; the `tokio` mutex serializes access to it.
    pub(crate) conn: Mutex<my::Conn>,
    /// The parsed connection URL this connector was created from.
    pub(crate) url: MysqlUrl,
    /// Timeout handed to `super::timeout::socket` for every I/O operation.
    socket_timeout: Option<Duration>,
    /// Cleared once an I/O operation fails with a closed-connection error.
    is_healthy: AtomicBool,
    /// Client-side LRU cache of prepared statements, keyed by SQL text.
    statement_cache: Mutex<LruCache<String, my::Statement>>,
}
43
/// A MySQL connection URL together with its parsed query-string options.
#[derive(Debug, Clone)]
pub struct MysqlUrl {
    /// The original connection URL.
    url: Url,
    /// Options parsed from the URL's query string when the value was constructed.
    query_params: MysqlUrlQueryParams,
}
50
51impl MysqlUrl {
52 pub fn new(url: Url) -> Result<Self, Error> {
55 let query_params = Self::parse_query_params(&url)?;
56
57 Ok(Self { url, query_params })
58 }
59
60 pub fn url(&self) -> &Url {
62 &self.url
63 }
64
65 pub fn username(&self) -> Cow<str> {
67 match percent_decode(self.url.username().as_bytes()).decode_utf8() {
68 Ok(username) => username,
69 Err(_) => {
70 tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version.");
71
72 self.url.username().into()
73 }
74 }
75 }
76
77 pub fn password(&self) -> Option<Cow<str>> {
79 match self
80 .url
81 .password()
82 .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok())
83 {
84 Some(password) => Some(password),
85 None => self.url.password().map(|s| s.into()),
86 }
87 }
88
89 pub fn dbname(&self) -> &str {
91 match self.url.path_segments() {
92 Some(mut segments) => segments.next().unwrap_or("mysql"),
93 None => "mysql",
94 }
95 }
96
97 pub fn host(&self) -> &str {
99 self.url.host_str().unwrap_or("localhost")
100 }
101
102 pub fn socket(&self) -> &Option<String> {
104 &self.query_params.socket
105 }
106
107 pub fn port(&self) -> u16 {
109 self.url.port().unwrap_or(3306)
110 }
111
112 pub fn connect_timeout(&self) -> Option<Duration> {
114 self.query_params.connect_timeout
115 }
116
117 pub fn pool_timeout(&self) -> Option<Duration> {
119 self.query_params.pool_timeout
120 }
121
122 pub fn socket_timeout(&self) -> Option<Duration> {
124 self.query_params.socket_timeout
125 }
126
127 pub fn prefer_socket(&self) -> Option<bool> {
129 self.query_params.prefer_socket
130 }
131
132 pub fn max_connection_lifetime(&self) -> Option<Duration> {
134 self.query_params.max_connection_lifetime
135 }
136
137 pub fn max_idle_connection_lifetime(&self) -> Option<Duration> {
139 self.query_params.max_idle_connection_lifetime
140 }
141
142 fn statement_cache_size(&self) -> usize {
143 self.query_params.statement_cache_size
144 }
145
146 pub(crate) fn cache(&self) -> LruCache<String, my::Statement> {
147 LruCache::new(self.query_params.statement_cache_size)
148 }
149
150 fn parse_query_params(url: &Url) -> Result<MysqlUrlQueryParams, Error> {
151 let mut ssl_opts = my::SslOpts::default();
152 ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true);
153
154 let mut connection_limit = None;
155 let mut use_ssl = false;
156 let mut socket = None;
157 let mut socket_timeout = None;
158 let mut connect_timeout = Some(Duration::from_secs(5));
159 let mut pool_timeout = Some(Duration::from_secs(10));
160 let mut max_connection_lifetime = None;
161 let mut max_idle_connection_lifetime = Some(Duration::from_secs(300));
162 let mut prefer_socket = None;
163 let mut statement_cache_size = 100;
164 let mut identity: Option<(Option<PathBuf>, Option<String>)> = None;
165
166 for (k, v) in url.query_pairs() {
167 match k.as_ref() {
168 "connection_limit" => {
169 let as_int: usize = v
170 .parse()
171 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
172
173 connection_limit = Some(as_int);
174 }
175 "statement_cache_size" => {
176 statement_cache_size = v
177 .parse()
178 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
179 }
180 "sslcert" => {
181 use_ssl = true;
182 ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf()));
183 }
184 "sslidentity" => {
185 use_ssl = true;
186
187 identity = match identity {
188 Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)),
189 None => Some((Some(Path::new(&*v).to_path_buf()), None)),
190 };
191 }
192 "sslpassword" => {
193 use_ssl = true;
194
195 identity = match identity {
196 Some((path, _)) => Some((path, Some(v.to_string()))),
197 None => Some((None, Some(v.to_string()))),
198 };
199 }
200 "socket" => {
201 socket = Some(v.replace(['(', ')'], ""));
202 }
203 "socket_timeout" => {
204 let as_int = v
205 .parse()
206 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
207 socket_timeout = Some(Duration::from_secs(as_int));
208 }
209 "prefer_socket" => {
210 let as_bool = v
211 .parse::<bool>()
212 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
213 prefer_socket = Some(as_bool)
214 }
215 "connect_timeout" => {
216 let as_int = v
217 .parse::<u64>()
218 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
219
220 connect_timeout = match as_int {
221 0 => None,
222 _ => Some(Duration::from_secs(as_int)),
223 };
224 }
225 "pool_timeout" => {
226 let as_int = v
227 .parse::<u64>()
228 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
229
230 pool_timeout = match as_int {
231 0 => None,
232 _ => Some(Duration::from_secs(as_int)),
233 };
234 }
235 "sslaccept" => {
236 use_ssl = true;
237 match v.as_ref() {
238 "strict" => {
239 ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false);
240 }
241 "accept_invalid_certs" => {}
242 _ => {
243 tracing::debug!(
244 message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`",
245 mode = &*v
246 );
247 }
248 };
249 }
250 "max_connection_lifetime" => {
251 let as_int = v
252 .parse()
253 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
254
255 if as_int == 0 {
256 max_connection_lifetime = None;
257 } else {
258 max_connection_lifetime = Some(Duration::from_secs(as_int));
259 }
260 }
261 "max_idle_connection_lifetime" => {
262 let as_int = v
263 .parse()
264 .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?;
265
266 if as_int == 0 {
267 max_idle_connection_lifetime = None;
268 } else {
269 max_idle_connection_lifetime = Some(Duration::from_secs(as_int));
270 }
271 }
272 _ => {
273 tracing::trace!(message = "Discarding connection string param", param = &*k);
274 }
275 };
276 }
277
278 ssl_opts = match identity {
279 Some((Some(path), Some(pw))) => {
280 let identity = mysql_async::ClientIdentity::new(path).with_password(pw);
281 ssl_opts.with_client_identity(Some(identity))
282 }
283 Some((Some(path), None)) => {
284 let identity = mysql_async::ClientIdentity::new(path);
285 ssl_opts.with_client_identity(Some(identity))
286 }
287 _ => ssl_opts,
288 };
289
290 Ok(MysqlUrlQueryParams {
291 ssl_opts,
292 connection_limit,
293 use_ssl,
294 socket,
295 socket_timeout,
296 connect_timeout,
297 pool_timeout,
298 max_connection_lifetime,
299 max_idle_connection_lifetime,
300 prefer_socket,
301 statement_cache_size,
302 })
303 }
304
305 #[cfg(feature = "pooled")]
306 pub(crate) fn connection_limit(&self) -> Option<usize> {
307 self.query_params.connection_limit
308 }
309
310 pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder {
311 let mut config = my::OptsBuilder::default()
312 .stmt_cache_size(Some(0))
313 .user(Some(self.username()))
314 .pass(self.password())
315 .db_name(Some(self.dbname()));
316
317 match self.socket() {
318 Some(ref socket) => {
319 config = config.socket(Some(socket));
320 }
321 None => {
322 config = config.ip_or_hostname(self.host()).tcp_port(self.port());
323 }
324 }
325
326 config = config.conn_ttl(Some(Duration::from_secs(5)));
327
328 if self.query_params.use_ssl {
329 config = config.ssl_opts(Some(self.query_params.ssl_opts.clone()));
330 }
331
332 if self.query_params.prefer_socket.is_some() {
333 config = config.prefer_socket(self.query_params.prefer_socket);
334 }
335
336 config
337 }
338}
339
/// Connection options parsed from the query string of a MySQL connection URL.
#[derive(Debug, Clone)]
pub(crate) struct MysqlUrlQueryParams {
    /// TLS options built from the `sslcert`, `sslidentity`, `sslpassword` and
    /// `sslaccept` parameters. Invalid certificates are accepted unless
    /// `sslaccept=strict` is given.
    ssl_opts: my::SslOpts,
    /// Value of the `connection_limit` parameter, if given.
    connection_limit: Option<usize>,
    /// Set when any `ssl*` parameter was present; `ssl_opts` is only applied then.
    use_ssl: bool,
    /// Unix socket path from the `socket` parameter, with `(` and `)` removed.
    socket: Option<String>,
    /// Value of the `socket_timeout` parameter (seconds); unset by default.
    socket_timeout: Option<Duration>,
    /// Value of the `connect_timeout` parameter (seconds); defaults to 5s, `None` when given as 0.
    connect_timeout: Option<Duration>,
    /// Value of the `pool_timeout` parameter (seconds); defaults to 10s, `None` when given as 0.
    pool_timeout: Option<Duration>,
    /// Value of the `max_connection_lifetime` parameter (seconds); unset by default, `None` when given as 0.
    max_connection_lifetime: Option<Duration>,
    /// Value of the `max_idle_connection_lifetime` parameter (seconds); defaults to 300s, `None` when given as 0.
    max_idle_connection_lifetime: Option<Duration>,
    /// Value of the `prefer_socket` parameter, forwarded to the driver only when set.
    prefer_socket: Option<bool>,
    /// Capacity of the prepared-statement LRU cache; defaults to 100. A value of
    /// 0 makes `Mysql` prepare and close a statement on every call instead.
    statement_cache_size: usize,
}
354
355impl Mysql {
356 pub async fn new(url: MysqlUrl) -> crate::Result<Self> {
358 let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?;
359
360 Ok(Self {
361 socket_timeout: url.query_params.socket_timeout,
362 conn: Mutex::new(conn),
363 statement_cache: Mutex::new(url.cache()),
364 url,
365 is_healthy: AtomicBool::new(true),
366 })
367 }
368
369 #[cfg(feature = "expose-drivers")]
373 pub fn conn(&self) -> &Mutex<mysql_async::Conn> {
374 &self.conn
375 }
376
377 async fn perform_io<F, U, T>(&self, op: U) -> crate::Result<T>
378 where
379 F: Future<Output = crate::Result<T>>,
380 U: FnOnce() -> F,
381 {
382 match super::timeout::socket(self.socket_timeout, op()).await {
383 Err(e) if e.is_closed() => {
384 self.is_healthy.store(false, Ordering::SeqCst);
385 Err(e)
386 }
387 res => Ok(res?),
388 }
389 }
390
391 async fn prepared<F, U, T>(&self, sql: &str, op: U) -> crate::Result<T>
392 where
393 F: Future<Output = crate::Result<T>>,
394 U: Fn(my::Statement) -> F,
395 {
396 if self.url.statement_cache_size() == 0 {
397 self.perform_io(|| async move {
398 let stmt = {
399 let mut conn = self.conn.lock().await;
400 conn.prep(sql).await?
401 };
402
403 let res = op(stmt.clone()).await;
404
405 {
406 let mut conn = self.conn.lock().await;
407 conn.close(stmt).await?;
408 }
409
410 res
411 })
412 .await
413 } else {
414 self.perform_io(|| async move {
415 let stmt = self.fetch_cached(sql).await?;
416 op(stmt).await
417 })
418 .await
419 }
420 }
421
422 async fn fetch_cached(&self, sql: &str) -> crate::Result<my::Statement> {
423 let mut cache = self.statement_cache.lock().await;
424 let capacity = cache.capacity();
425 let stored = cache.len();
426
427 match cache.get_mut(sql) {
428 Some(stmt) => {
429 tracing::trace!(
430 message = "CACHE HIT!",
431 query = sql,
432 capacity = capacity,
433 stored = stored,
434 );
435
436 Ok(stmt.clone()) }
438 None => {
439 tracing::trace!(
440 message = "CACHE MISS!",
441 query = sql,
442 capacity = capacity,
443 stored = stored,
444 );
445
446 let mut conn = self.conn.lock().await;
447 if cache.capacity() == cache.len() {
448 if let Some((_, stmt)) = cache.remove_lru() {
449 conn.close(stmt).await?;
450 }
451 }
452
453 let stmt = conn.prep(sql).await?;
454 cache.insert(sql.to_string(), stmt.clone());
455
456 Ok(stmt)
457 }
458 }
459 }
460}
461
// MySQL relies entirely on the trait's provided `TransactionCapable` methods.
impl TransactionCapable for Mysql {}
463
464#[async_trait]
465impl Queryable for Mysql {
466 async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
467 let (sql, params) = visitor::Mysql::build(q)?;
468 self.query_raw(&sql, ¶ms).await
469 }
470
471 async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
472 metrics::query("mysql.query_raw", sql, params, move || async move {
473 self.prepared(sql, |stmt| async move {
474 let mut conn = self.conn.lock().await;
475 let rows: Vec<my::Row> = conn.exec(&stmt, conversion::conv_params(params)?).await?;
476 let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect();
477
478 let last_id = conn.last_insert_id();
479 let mut result_set = ResultSet::new(columns, Vec::new());
480
481 for mut row in rows {
482 result_set.rows.push(row.take_result_row()?);
483 }
484
485 if let Some(id) = last_id {
486 result_set.set_last_insert_id(id);
487 };
488
489 Ok(result_set)
490 })
491 .await
492 })
493 .await
494 }
495
496 async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<ResultSet> {
497 self.query_raw(sql, params).await
498 }
499
500 async fn execute(&self, q: Query<'_>) -> crate::Result<u64> {
501 let (sql, params) = visitor::Mysql::build(q)?;
502 self.execute_raw(&sql, ¶ms).await
503 }
504
505 async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
506 metrics::query("mysql.execute_raw", sql, params, move || async move {
507 self.prepared(sql, |stmt| async move {
508 let mut conn = self.conn.lock().await;
509 conn.exec_drop(stmt, conversion::conv_params(params)?).await?;
510
511 Ok(conn.affected_rows())
512 })
513 .await
514 })
515 .await
516 }
517
518 async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result<u64> {
519 self.execute_raw(sql, params).await
520 }
521
522 async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> {
523 metrics::query("mysql.raw_cmd", cmd, &[], move || async move {
524 self.perform_io(|| async move {
525 let mut conn = self.conn.lock().await;
526 let mut result = cmd.run(&mut *conn).await?;
527
528 loop {
529 result.map(drop).await?;
530
531 if result.is_empty() {
532 result.map(drop).await?;
533 break;
534 }
535 }
536
537 Ok(())
538 })
539 .await
540 })
541 .await
542 }
543
544 async fn version(&self) -> crate::Result<Option<String>> {
545 let query = r#"SELECT @@GLOBAL.version version"#;
546 let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?;
547
548 let version_string = rows
549 .get(0)
550 .and_then(|row| row.get("version").and_then(|version| version.to_string()));
551
552 Ok(version_string)
553 }
554
555 fn is_healthy(&self) -> bool {
556 self.is_healthy.load(Ordering::SeqCst)
557 }
558
559 async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> {
560 if matches!(isolation_level, IsolationLevel::Snapshot) {
561 return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build());
562 }
563
564 self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}"))
565 .await?;
566
567 Ok(())
568 }
569
570 fn requires_isolation_first(&self) -> bool {
571 true
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::MysqlUrl;
    use crate::tests::test_api::mysql::CONN_STR;
    use crate::{error::*, single::Quaint};
    use url::Url;

    #[test]
    fn should_parse_socket_url() {
        let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap();
        assert_eq!("dbname", url.dbname());
        assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket());
    }

    #[test]
    fn should_parse_prefer_socket() {
        let url =
            MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap();
        assert_eq!(Some(false), url.prefer_socket());
    }

    #[test]
    fn should_parse_sslaccept() {
        let url =
            MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap();
        assert!(url.query_params.use_ssl);
        assert!(!url.query_params.ssl_opts.skip_domain_validation());
        assert!(!url.query_params.ssl_opts.accept_invalid_certs());
    }

    #[test]
    fn should_treat_zero_timeouts_as_disabled() {
        let url = MysqlUrl::new(
            Url::parse("mysql://root:root@localhost:3307/testdb?connect_timeout=0&pool_timeout=0").unwrap(),
        )
        .unwrap();
        assert_eq!(None, url.connect_timeout());
        assert_eq!(None, url.pool_timeout());
    }

    #[test]
    fn should_reject_non_numeric_params() {
        let res =
            MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?statement_cache_size=lots").unwrap());
        assert!(matches!(
            res.unwrap_err().kind(),
            ErrorKind::InvalidConnectionArguments
        ));
    }

    #[test]
    fn should_allow_changing_of_cache_size() {
        let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap())
            .unwrap();
        assert_eq!(420, url.cache().capacity());
    }

    #[test]
    fn should_have_default_cache_size() {
        let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap();
        assert_eq!(100, url.cache().capacity());
    }

    #[tokio::test]
    async fn should_map_nonexisting_database_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_username("root").unwrap();
        url.set_path("/this_does_not_exist");

        let url = url.as_str().to_string();
        let res = Quaint::new(&url).await;

        let err = res.unwrap_err();

        match err.kind() {
            ErrorKind::DatabaseDoesNotExist { db_name } => {
                assert_eq!(Some("1049"), err.original_code());
                assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message());
                assert_eq!(&Name::available("this_does_not_exist"), db_name)
            }
            e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e),
        }
    }

    #[tokio::test]
    async fn should_map_wrong_credentials_error() {
        let mut url = Url::parse(&CONN_STR).unwrap();
        url.set_username("WRONG").unwrap();

        let res = Quaint::new(url.as_str()).await;
        assert!(res.is_err());

        let err = res.unwrap_err();
        assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG")));
    }
}