mysql/conn/opts/
mod.rs

1// Copyright (c) 2020 rust-mysql-simple contributors
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use percent_encoding::percent_decode;
10use url::Url;
11
12use std::{
13    borrow::Cow, collections::HashMap, hash::Hash, net::SocketAddr, path::Path, time::Duration,
14};
15
16use crate::{consts::CapabilityFlags, Compression, LocalInfileHandler, UrlError};
17
18/// Default value for client side per-connection statement cache.
19pub const DEFAULT_STMT_CACHE_SIZE: usize = 32;
20
21mod native_tls_opts;
22mod rustls_opts;
23
24#[cfg(feature = "native-tls")]
25pub use native_tls_opts::ClientIdentity;
26
27#[cfg(feature = "rustls-tls")]
28pub use rustls_opts::ClientIdentity;
29
/// TLS/SSL settings for a connection.
///
/// Start from [`SslOpts::default`] and refine the value with the `with_*` methods; each of
/// them consumes the options and hands back the updated copy.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
pub struct SslOpts {
    /// Client identity (only exists when a TLS backend feature is enabled).
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    client_identity: Option<ClientIdentity>,
    /// Path to the root certificate the connector will trust.
    root_cert_path: Option<Cow<'static, Path>>,
    /// Skip validating the server's domain name against its certificate.
    skip_domain_validation: bool,
    /// Accept invalid certificates.
    accept_invalid_certs: bool,
}

impl SslOpts {
    /// Replaces the client identity (`None` clears it).
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    pub fn with_client_identity(self, identity: Option<ClientIdentity>) -> Self {
        Self {
            client_identity: identity,
            ..self
        }
    }

    /// Replaces the path to the root certificate that the connector will trust.
    ///
    /// Supported certificate formats are .der and .pem.
    /// A .pem file may contain multiple certificates.
    pub fn with_root_cert_path<T: Into<Cow<'static, Path>>>(
        self,
        root_cert_path: Option<T>,
    ) -> Self {
        Self {
            root_cert_path: root_cert_path.map(|path| path.into()),
            ..self
        }
    }

    /// Controls whether validation of the server's domain name against its
    /// certificate is skipped (defaults to `false`).
    pub fn with_danger_skip_domain_validation(self, value: bool) -> Self {
        Self {
            skip_domain_validation: value,
            ..self
        }
    }

    /// When set to `true` the client accepts invalid certificates (expired, not trusted, ..).
    /// Defaults to `false`.
    pub fn with_danger_accept_invalid_certs(self, value: bool) -> Self {
        Self {
            accept_invalid_certs: value,
            ..self
        }
    }

    /// Currently configured client identity, if any.
    #[cfg(any(feature = "native-tls", feature = "rustls-tls"))]
    pub fn client_identity(&self) -> Option<&ClientIdentity> {
        match &self.client_identity {
            Some(identity) => Some(identity),
            None => None,
        }
    }

    /// Currently configured root certificate path, if any.
    pub fn root_cert_path(&self) -> Option<&Path> {
        self.root_cert_path.as_deref()
    }

    /// Whether domain name validation is skipped.
    pub fn skip_domain_validation(&self) -> bool {
        let Self {
            skip_domain_validation,
            ..
        } = self;
        *skip_domain_validation
    }

    /// Whether invalid certificates are accepted.
    pub fn accept_invalid_certs(&self) -> bool {
        let Self {
            accept_invalid_certs,
            ..
        } = self;
        *accept_invalid_certs
    }
}
91
/// Storage for every connection option.
///
/// The options structure is quite large, so it is stored separately (boxed inside `Opts`).
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct InnerOpts {
    /// Address of mysql server: an IP address or a hostname.
    ///
    /// The `Default` impl of this struct uses the `localhost` domain.
    ip_or_hostname: url::Host,
    /// TCP port of mysql server (defaults to `3306`).
    tcp_port: u16,
    /// Path to unix socket on unix or pipe name on windows (defaults to `None`).
    ///
    /// Can be defined using `socket` connection url parameter.
    socket: Option<String>,
    /// User (defaults to `None`).
    user: Option<String>,
    /// Password (defaults to `None`).
    pass: Option<String>,
    /// Database name (defaults to `None`).
    db_name: Option<String>,

    /// The timeout for each attempt to read from the server (defaults to `None`).
    read_timeout: Option<Duration>,

    /// The timeout for each attempt to write to the server (defaults to `None`).
    write_timeout: Option<Duration>,

    /// Prefer socket connection (defaults to `true`).
    ///
    /// Will reconnect via socket (or named pipe on windows) after TCP
    /// connection to `127.0.0.1` if `true`.
    ///
    /// Will fall back to TCP on error. Use `socket` option to enforce socket connection.
    ///
    /// Can be defined using `prefer_socket` connection url parameter.
    prefer_socket: bool,

    /// Whether to enable `TCP_NODELAY` (defaults to `true`).
    ///
    /// This option disables Nagle's algorithm, which can cause unusually high latency (~40ms) at
    /// some cost to maximum throughput. See #132.
    tcp_nodelay: bool,

    /// TCP keep alive time for mysql connection (defaults to `None`).
    ///
    /// Can be defined using `tcp_keepalive_time_ms` connection url parameter.
    tcp_keepalive_time: Option<u32>,

    /// TCP keep alive interval between subsequent probes for mysql connection
    /// (defaults to `None`).
    ///
    /// Can be defined using `tcp_keepalive_probe_interval_secs` connection url parameter.
    #[cfg(any(target_os = "linux", target_os = "macos",))]
    tcp_keepalive_probe_interval_secs: Option<u32>,

    /// TCP keep alive probe count for mysql connection (defaults to `None`).
    ///
    /// Can be defined using `tcp_keepalive_probe_count` connection url parameter.
    #[cfg(any(target_os = "linux", target_os = "macos",))]
    tcp_keepalive_probe_count: Option<u32>,

    /// TCP_USER_TIMEOUT time for mysql connection (defaults to `None`).
    ///
    /// Can be defined using `tcp_user_timeout_ms` connection url parameter.
    #[cfg(target_os = "linux")]
    tcp_user_timeout: Option<u32>,

    /// Commands to execute on each new database connection (defaults to empty).
    init: Vec<String>,

    /// Driver will require SSL connection if this option isn't `None` (defaults to `None`).
    ssl_opts: Option<SslOpts>,

    /// Callback to handle requests for local files.
    ///
    /// These are caused by using `LOAD DATA LOCAL INFILE` queries.
    /// The callback is passed the filename, and a `Write`able object
    /// to receive the contents of that file.
    ///
    /// If unset, the default callback will read files relative to
    /// the current directory.
    local_infile_handler: Option<LocalInfileHandler>,

    /// Tcp connect timeout (defaults to `None`).
    ///
    /// Can be defined using `tcp_connect_timeout_ms` connection url parameter.
    tcp_connect_timeout: Option<Duration>,

    /// Bind address for a client (defaults to `None`).
    ///
    /// Use carefully. Will probably make pool unusable because of *address already in use*
    /// errors.
    bind_address: Option<SocketAddr>,

    /// Number of prepared statements cached on the client side (per connection).
    /// Defaults to [`DEFAULT_STMT_CACHE_SIZE`].
    ///
    /// Can be defined using `stmt_cache_size` connection url parameter.
    stmt_cache_size: usize,

    /// If not `None`, then client will ask for compression if server supports it
    /// (defaults to `None`).
    ///
    /// Can be defined using `compress` connection url parameter with values `true`, `fast`, `best`,
    /// `0`, `1`, ..., `9`.
    ///
    /// Note that compression level defined here will affect only outgoing packets.
    compress: Option<crate::Compression>,

    /// Additional client capabilities to set (defaults to empty).
    ///
    /// This value will be OR'ed with other client capabilities during connection initialisation.
    ///
    /// ### Note
    ///
    /// It is a good way to set something like `CLIENT_FOUND_ROWS` but you should note that it
    /// won't let you interfere with capabilities managed by other options (like
    /// `CLIENT_SSL` or `CLIENT_COMPRESS`). Also note that some capabilities are reserved,
    /// pointless or may break the connection, so this option should be used with caution.
    additional_capabilities: CapabilityFlags,

    /// Connect attributes (defaults to empty).
    connect_attrs: HashMap<String, String>,

    /// Disables `mysql_old_password` plugin (defaults to `true`).
    ///
    /// Available via `secure_auth` connection url parameter.
    secure_auth: bool,

    /// For tests only (only compiled under `cfg(test)`).
    #[cfg(test)]
    pub injected_socket: Option<String>,
}
221
222impl Default for InnerOpts {
223    fn default() -> Self {
224        InnerOpts {
225            ip_or_hostname: url::Host::Domain(String::from("localhost")),
226            tcp_port: 3306,
227            socket: None,
228            user: None,
229            pass: None,
230            db_name: None,
231            read_timeout: None,
232            write_timeout: None,
233            prefer_socket: true,
234            init: vec![],
235            ssl_opts: None,
236            tcp_keepalive_time: None,
237            #[cfg(any(target_os = "linux", target_os = "macos",))]
238            tcp_keepalive_probe_interval_secs: None,
239            #[cfg(any(target_os = "linux", target_os = "macos",))]
240            tcp_keepalive_probe_count: None,
241            #[cfg(target_os = "linux")]
242            tcp_user_timeout: None,
243            tcp_nodelay: true,
244            local_infile_handler: None,
245            tcp_connect_timeout: None,
246            bind_address: None,
247            stmt_cache_size: DEFAULT_STMT_CACHE_SIZE,
248            compress: None,
249            additional_capabilities: CapabilityFlags::empty(),
250            connect_attrs: HashMap::new(),
251            secure_auth: true,
252            #[cfg(test)]
253            injected_socket: None,
254        }
255    }
256}
257
258impl TryFrom<&'_ str> for Opts {
259    type Error = UrlError;
260
261    fn try_from(url: &'_ str) -> Result<Self, Self::Error> {
262        Opts::from_url(url)
263    }
264}
265
/// Mysql connection options.
///
/// Build one with [`OptsBuilder`](struct.OptsBuilder.html) or parse one from a connection
/// URL with [`Opts::from_url`].
///
/// The actual fields live in a boxed `InnerOpts`, so this wrapper is a single pointer.
#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct Opts(pub(crate) Box<InnerOpts>);
271
272impl Opts {
273    #[doc(hidden)]
274    pub fn addr_is_loopback(&self) -> bool {
275        match self.0.ip_or_hostname {
276            url::Host::Domain(ref name) => name == "localhost",
277            url::Host::Ipv4(ref addr) => addr.is_loopback(),
278            url::Host::Ipv6(ref addr) => addr.is_loopback(),
279        }
280    }
281
282    pub fn from_url(url: &str) -> Result<Opts, UrlError> {
283        from_url(url)
284    }
285
286    pub(crate) fn get_host(&self) -> url::Host {
287        self.0.ip_or_hostname.clone()
288    }
289
290    /// Address of mysql server (defaults to `127.0.0.1`). Hostnames should also work.
291    pub fn get_ip_or_hostname(&self) -> Cow<str> {
292        self.0.ip_or_hostname.to_string().into()
293    }
294    /// TCP port of mysql server (defaults to `3306`).
295    pub fn get_tcp_port(&self) -> u16 {
296        self.0.tcp_port
297    }
298    /// Socket path on unix or pipe name on windows (defaults to `None`).
299    pub fn get_socket(&self) -> Option<&str> {
300        self.0.socket.as_deref()
301    }
302    /// User (defaults to `None`).
303    pub fn get_user(&self) -> Option<&str> {
304        self.0.user.as_deref()
305    }
306    /// Password (defaults to `None`).
307    pub fn get_pass(&self) -> Option<&str> {
308        self.0.pass.as_deref()
309    }
310    /// Database name (defaults to `None`).
311    pub fn get_db_name(&self) -> Option<&str> {
312        self.0.db_name.as_deref()
313    }
314
315    /// The timeout for each attempt to write to the server.
316    pub fn get_read_timeout(&self) -> Option<&Duration> {
317        self.0.read_timeout.as_ref()
318    }
319
320    /// The timeout for each attempt to write to the server.
321    pub fn get_write_timeout(&self) -> Option<&Duration> {
322        self.0.write_timeout.as_ref()
323    }
324
325    /// Prefer socket connection (defaults to `true`).
326    ///
327    /// Will reconnect via socket (or named pipe on windows) after TCP connection
328    /// to `127.0.0.1` if `true`.
329    ///
330    /// Will fall back to TCP on error. Use `socket` option to enforce socket connection.
331    pub fn get_prefer_socket(&self) -> bool {
332        self.0.prefer_socket
333    }
334    // XXX: Wait for keepalive_timeout stabilization
335    /// Commands to execute on each new database connection.
336    pub fn get_init(&self) -> Vec<String> {
337        self.0.init.clone()
338    }
339
340    /// Driver will require SSL connection if this option isn't `None` (default to `None`).
341    pub fn get_ssl_opts(&self) -> Option<&SslOpts> {
342        self.0.ssl_opts.as_ref()
343    }
344
345    /// Whether `TCP_NODELAY` will be set for mysql connection.
346    pub fn get_tcp_nodelay(&self) -> bool {
347        self.0.tcp_nodelay
348    }
349
350    /// TCP keep alive time for mysql connection.
351    pub fn get_tcp_keepalive_time_ms(&self) -> Option<u32> {
352        self.0.tcp_keepalive_time
353    }
354
355    /// TCP keep alive interval between subsequent probes for mysql connection.
356    #[cfg(any(target_os = "linux", target_os = "macos",))]
357    pub fn get_tcp_keepalive_probe_interval_secs(&self) -> Option<u32> {
358        self.0.tcp_keepalive_probe_interval_secs
359    }
360
361    /// TCP keep alive probe count for mysql connection.
362    #[cfg(any(target_os = "linux", target_os = "macos",))]
363    pub fn get_tcp_keepalive_probe_count(&self) -> Option<u32> {
364        self.0.tcp_keepalive_probe_count
365    }
366
367    /// TCP_USER_TIMEOUT time for mysql connection.
368    #[cfg(target_os = "linux")]
369    pub fn get_tcp_user_timeout_ms(&self) -> Option<u32> {
370        self.0.tcp_user_timeout
371    }
372
373    /// Callback to handle requests for local files.
374    pub fn get_local_infile_handler(&self) -> Option<&LocalInfileHandler> {
375        self.0.local_infile_handler.as_ref()
376    }
377
378    /// Tcp connect timeout (defaults to `None`).
379    pub fn get_tcp_connect_timeout(&self) -> Option<Duration> {
380        self.0.tcp_connect_timeout
381    }
382
383    /// Bind address for a client (defaults to `None`).
384    ///
385    /// Use carefully. Will probably make pool unusable because of *address already in use*
386    /// errors.
387    pub fn bind_address(&self) -> Option<&SocketAddr> {
388        self.0.bind_address.as_ref()
389    }
390
391    /// Number of prepared statements cached on the client side (per connection).
392    /// Defaults to [`DEFAULT_STMT_CACHE_SIZE`].
393    ///
394    /// Can be defined using `stmt_cache_size` connection url parameter.
395    pub fn get_stmt_cache_size(&self) -> usize {
396        self.0.stmt_cache_size
397    }
398
399    /// If not `None`, then client will ask for compression if server supports it
400    /// (defaults to `None`).
401    ///
402    /// Can be defined using `compress` connection url parameter with values:
403    /// * `true` - library defined default compression level;
404    /// * `fast` - library defined fast compression level;
405    /// * `best` - library defined best compression level;
406    /// * `0`, `1`, ..., `9` - explicitly defined compression level where `0` stands for
407    ///   "no compression";
408    ///
409    /// Note that compression level defined here will affect only outgoing packets.
410    pub fn get_compress(&self) -> Option<crate::Compression> {
411        self.0.compress
412    }
413
414    /// Additional client capabilities to set (defaults to empty).
415    ///
416    /// This value will be OR'ed with other client capabilities during connection initialisation.
417    ///
418    /// ### Note
419    ///
420    /// It is a good way to set something like `CLIENT_FOUND_ROWS` but you should note that it
421    /// won't let you to interfere with capabilities managed by other options (like
422    /// `CLIENT_SSL` or `CLIENT_COMPRESS`). Also note that some capabilities are reserved,
423    /// pointless or may broke the connection, so this option should be used with caution.
424    pub fn get_additional_capabilities(&self) -> CapabilityFlags {
425        self.0.additional_capabilities
426    }
427
428    /// Connect attributes
429    ///
430    /// This value is sent to the server as custom name-value attributes.
431    /// You can see them from performance_schema tables: [`session_account_connect_attrs`
432    /// and `session_connect_attrs`][attr_tables] when all of the following conditions
433    /// are met.
434    ///
435    /// * The server is MySQL 5.6 or later, or MariaDB 10.0 or later.
436    /// * [`performance_schema`] is on.
437    /// * [`performance_schema_session_connect_attrs_size`] is -1 or big enough
438    ///   to store specified attributes.
439    ///
440    /// ### Note
441    ///
442    /// Attribute names that begin with an underscore (`_`) are not set by
443    /// application programs because they are reserved for internal use.
444    ///
445    /// The following attributes are sent in addition to ones set by programs.
446    ///
447    /// name            | value
448    /// ----------------|--------------------------
449    /// _client_name    | The client library name (`rust-mysql-simple`)
450    /// _client_version | The client library version
451    /// _os             | The operation system (`target_os` cfg feature)
452    /// _pid            | The client process ID
453    /// _platform       | The machine platform (`target_arch` cfg feature)
454    /// program_name    | The first element of `std::env::args` if program_name isn't set by programs.
455    ///
456    /// [attr_tables]: https://dev.mysql.com/doc/refman/en/performance-schema-connection-attribute-tables.html
457    /// [`performance_schema`]: https://dev.mysql.com/doc/refman/8.0/en/performance-schema-system-variables.html#sysvar_performance_schema
458    /// [`performance_schema_session_connect_attrs_size`]: https://dev.mysql.com/doc/refman/en/performance-schema-system-variables.html#sysvar_performance_schema_session_connect_attrs_size
459    ///
460    pub fn get_connect_attrs(&self) -> &HashMap<String, String> {
461        &self.0.connect_attrs
462    }
463
464    /// Disables `mysql_old_password` plugin (defaults to `true`).
465    ///
466    /// Available via `secure_auth` connection url parameter.
467    pub fn get_secure_auth(&self) -> bool {
468        self.0.secure_auth
469    }
470}
471
/// Provides a way to build [`Opts`](struct.Opts.html).
///
/// Every setter consumes the builder and returns it, so calls can be chained.
///
/// ```ignore
/// let ssl_opts = SslOpts::default()
///     .with_root_cert_path(Some(Path::new("/foo/root_ca.der")));
///
/// // You can create new default builder
/// let mut builder = OptsBuilder::new();
/// builder = builder.ip_or_hostname(Some("foo"))
///        .db_name(Some("bar"))
///        .ssl_opts(Some(ssl_opts));
///
/// // Or use existing T: Into<Opts>
/// let builder = OptsBuilder::from_opts(existing_opts)
///        .ip_or_hostname(Some("foo"))
///        .db_name(Some("bar"));
/// ```
///
/// ## Connection URL
///
/// `Opts` can also be constructed from a connection URL. See docs on `OptsBuilder`'s methods for
/// the list of options available via URL.
///
/// Example:
///
/// ```ignore
/// let connection_opts = mysql::Opts::from_url("mysql://root:password@localhost:3307/mysql?prefer_socket=false").unwrap();
/// let pool = mysql::Pool::new(connection_opts).unwrap();
/// ```
#[derive(Debug, Clone, PartialEq)]
pub struct OptsBuilder {
    opts: Opts,
}
506
507impl OptsBuilder {
508    pub fn new() -> Self {
509        OptsBuilder::default()
510    }
511
512    pub fn from_opts<T: Into<Opts>>(opts: T) -> Self {
513        OptsBuilder { opts: opts.into() }
514    }
515
516    /// Use a HashMap for creating an OptsBuilder instance:
517    /// ```ignore
518    /// OptsBuilder::new().from_hash_map(client);
519    /// ```
520    /// `HashMap` key,value pairs:
521    /// - user = Username
522    /// - password = Password
523    /// - host = Host name or ip address
524    /// - port = Port, default is 3306
525    /// - socket = Unix socket or pipe name(on windows) defaults to `None`
526    /// - db_name = Database name (defaults to `None`).
527    /// - prefer_socket = Prefer socket connection (defaults to `true`)
528    /// - tcp_keepalive_time_ms = TCP keep alive time for mysql connection (defaults to `None`)
529    /// - tcp_keepalive_probe_interval_secs = TCP keep alive interval between probes for mysql connection (defaults to `None`)
530    /// - tcp_keepalive_probe_count = TCP keep alive probe count for mysql connection (defaults to `None`)
531    /// - tcp_user_timeout_ms = TCP_USER_TIMEOUT time for mysql connection (defaults to `None`)
532    /// - compress = Compression level(defaults to `None`)
533    /// - tcp_connect_timeout_ms = Tcp connect timeout (defaults to `None`)
534    /// - stmt_cache_size = Number of prepared statements cached on the client side (per connection)
535    /// - secure_auth = Disable `mysql_old_password` auth plugin
536    ///
537    /// Login .cnf file parsing lib <https://github.com/rjcortese/myloginrs> returns a HashMap for client configs
538    ///
539    /// **Note:** You do **not** have to use myloginrs lib.
540    pub fn from_hash_map(mut self, client: &HashMap<String, String>) -> Result<Self, UrlError> {
541        for (key, value) in client.iter() {
542            match key.as_str() {
543                "user" => self.opts.0.user = Some(value.to_string()),
544                "password" => self.opts.0.pass = Some(value.to_string()),
545                "host" => {
546                    let host = url::Host::parse(value)
547                        .unwrap_or_else(|_| url::Host::Domain(value.to_owned()));
548                    self.opts.0.ip_or_hostname = host;
549                }
550                "port" => match value.parse::<u16>() {
551                    Ok(parsed) => self.opts.0.tcp_port = parsed,
552                    Err(_) => {
553                        return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
554                    }
555                },
556                "socket" => self.opts.0.socket = Some(value.to_string()),
557                "db_name" => self.opts.0.db_name = Some(value.to_string()),
558                "prefer_socket" => {
559                    //default to true like standard opts builder method
560                    match value.parse::<bool>() {
561                        Ok(parsed) => self.opts.0.prefer_socket = parsed,
562                        Err(_) => {
563                            return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
564                        }
565                    }
566                }
567                "secure_auth" => match value.parse::<bool>() {
568                    Ok(parsed) => self.opts.0.secure_auth = parsed,
569                    Err(_) => {
570                        return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
571                    }
572                },
573                "tcp_keepalive_time_ms" => {
574                    //if cannot parse, default to none
575                    self.opts.0.tcp_keepalive_time = match value.parse::<u32>() {
576                        Ok(val) => Some(val),
577                        _ => {
578                            return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
579                        }
580                    }
581                }
582                #[cfg(any(target_os = "linux", target_os = "macos",))]
583                "tcp_keepalive_probe_interval_secs" => {
584                    //if cannot parse, default to none
585                    self.opts.0.tcp_keepalive_probe_interval_secs = match value.parse::<u32>() {
586                        Ok(val) => Some(val),
587                        _ => {
588                            return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
589                        }
590                    }
591                }
592                #[cfg(any(target_os = "linux", target_os = "macos",))]
593                "tcp_keepalive_probe_count" => {
594                    //if cannot parse, default to none
595                    self.opts.0.tcp_keepalive_probe_count = match value.parse::<u32>() {
596                        Ok(val) => Some(val),
597                        _ => {
598                            return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
599                        }
600                    }
601                }
602                #[cfg(target_os = "linux")]
603                "tcp_user_timeout_ms" => {
604                    self.opts.0.tcp_user_timeout = match value.parse::<u32>() {
605                        Ok(val) => Some(val),
606                        _ => {
607                            return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
608                        }
609                    }
610                }
611                "compress" => match value.parse::<u32>() {
612                    Ok(val) => self.opts.0.compress = Some(Compression::new(val)),
613                    Err(_) => {
614                        //not an int
615                        match value.as_str() {
616                            "fast" => self.opts.0.compress = Some(Compression::fast()),
617                            "best" => self.opts.0.compress = Some(Compression::best()),
618                            "true" => self.opts.0.compress = Some(Compression::default()),
619                            _ => {
620                                return Err(UrlError::InvalidValue(
621                                    key.to_string(),
622                                    value.to_string(),
623                                )); //should not go below this due to catch all
624                            }
625                        }
626                    }
627                },
628                "tcp_connect_timeout_ms" => {
629                    self.opts.0.tcp_connect_timeout = match value.parse::<u64>() {
630                        Ok(val) => Some(Duration::from_millis(val)),
631                        _ => {
632                            return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
633                        }
634                    }
635                }
636                "stmt_cache_size" => match value.parse::<usize>() {
637                    Ok(parsed) => self.opts.0.stmt_cache_size = parsed,
638                    Err(_) => {
639                        return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
640                    }
641                },
642                _ => {
643                    //throw an error if there is an unrecognized param
644                    return Err(UrlError::UnknownParameter(key.to_string()));
645                }
646            }
647        }
648        Ok(self)
649    }
650
651    /// Address of mysql server (defaults to `127.0.0.1`). Hostnames should also work.
652    ///
653    /// **Note:** IPv6 addresses must be given in square brackets, e.g. `[::1]`.
654    pub fn ip_or_hostname<T: Into<String>>(mut self, ip_or_hostname: Option<T>) -> Self {
655        let new = ip_or_hostname
656            .map(Into::into)
657            .unwrap_or_else(|| "127.0.0.1".into());
658        self.opts.0.ip_or_hostname =
659            url::Host::parse(&new).unwrap_or_else(|_| url::Host::Domain(new.to_owned()));
660        self
661    }
662
663    /// TCP port of mysql server (defaults to `3306`).
664    pub fn tcp_port(mut self, tcp_port: u16) -> Self {
665        self.opts.0.tcp_port = tcp_port;
666        self
667    }
668
669    /// Socket path on unix or pipe name on windows (defaults to `None`).
670    ///
671    /// Can be defined using `socket` connection url parameter.
672    pub fn socket<T: Into<String>>(mut self, socket: Option<T>) -> Self {
673        self.opts.0.socket = socket.map(Into::into);
674        self
675    }
676
677    /// User (defaults to `None`).
678    pub fn user<T: Into<String>>(mut self, user: Option<T>) -> Self {
679        self.opts.0.user = user.map(Into::into);
680        self
681    }
682
683    /// Password (defaults to `None`).
684    pub fn pass<T: Into<String>>(mut self, pass: Option<T>) -> Self {
685        self.opts.0.pass = pass.map(Into::into);
686        self
687    }
688
689    /// Database name (defaults to `None`).
690    pub fn db_name<T: Into<String>>(mut self, db_name: Option<T>) -> Self {
691        self.opts.0.db_name = db_name.map(Into::into);
692        self
693    }
694
695    /// The timeout for each attempt to read from the server (defaults to `None`).
696    ///
697    /// Note that named pipe connection will ignore duration's `nanos`, and also note that
698    /// it is an error to pass the zero `Duration` to this method.
699    pub fn read_timeout(mut self, read_timeout: Option<Duration>) -> Self {
700        self.opts.0.read_timeout = read_timeout;
701        self
702    }
703
704    /// The timeout for each attempt to write to the server (defaults to `None`).
705    ///
706    /// Note that named pipe connection will ignore duration's `nanos`, and also note that
707    /// it is likely error to pass the zero `Duration` to this method.
708    pub fn write_timeout(mut self, write_timeout: Option<Duration>) -> Self {
709        self.opts.0.write_timeout = write_timeout;
710        self
711    }
712
713    /// TCP keep alive time for mysql connection (defaults to `None`). Available as
714    /// `tcp_keepalive_time_ms` url parameter.
715    ///
716    /// Can be defined using `tcp_keepalive_time_ms` connection url parameter.
717    pub fn tcp_keepalive_time_ms(mut self, tcp_keepalive_time_ms: Option<u32>) -> Self {
718        self.opts.0.tcp_keepalive_time = tcp_keepalive_time_ms;
719        self
720    }
721
722    /// TCP keep alive interval between probes for mysql connection (defaults to `None`). Available as
723    /// `tcp_keepalive_probe_interval_secs` url parameter.
724    ///
725    /// Can be defined using `tcp_keepalive_probe_interval_secs` connection url parameter.
726    #[cfg(any(target_os = "linux", target_os = "macos",))]
727    pub fn tcp_keepalive_probe_interval_secs(
728        mut self,
729        tcp_keepalive_probe_interval_secs: Option<u32>,
730    ) -> Self {
731        self.opts.0.tcp_keepalive_probe_interval_secs = tcp_keepalive_probe_interval_secs;
732        self
733    }
734
735    /// TCP keep alive probe count for mysql connection (defaults to `None`). Available as
736    /// `tcp_keepalive_probe_count` url parameter.
737    ///
738    /// Can be defined using `tcp_keepalive_probe_count` connection url parameter.
739    #[cfg(any(target_os = "linux", target_os = "macos",))]
740    pub fn tcp_keepalive_probe_count(mut self, tcp_keepalive_probe_count: Option<u32>) -> Self {
741        self.opts.0.tcp_keepalive_probe_count = tcp_keepalive_probe_count;
742        self
743    }
744
745    /// TCP_USER_TIMEOUT for mysql connection (defaults to `None`). Available as
746    /// `tcp_user_timeout_ms` url parameter.
747    ///
748    /// Can be defined using `tcp_user_timeout_ms` connection url parameter.
749    #[cfg(target_os = "linux")]
750    pub fn tcp_user_timeout_ms(mut self, tcp_user_timeout_ms: Option<u32>) -> Self {
751        self.opts.0.tcp_user_timeout = tcp_user_timeout_ms;
752        self
753    }
754
755    /// Set the `TCP_NODELAY` option for the mysql connection (defaults to `true`).
756    ///
757    /// Setting this option to false re-enables Nagle's algorithm, which can cause unusually high
758    /// latency (~40ms) but may increase maximum throughput. See #132.
759    pub fn tcp_nodelay(mut self, nodelay: bool) -> Self {
760        self.opts.0.tcp_nodelay = nodelay;
761        self
762    }
763
764    /// Prefer socket connection (defaults to `true`). Available as `prefer_socket` url parameter
765    /// with value `true` or `false`.
766    ///
767    /// Will reconnect via socket (on named pipe on windows) after TCP connection
768    /// to `127.0.0.1` if `true`.
769    ///
770    /// Will fall back to TCP on error. Use `socket` option to enforce socket connection.
771    ///
772    /// Can be defined using `prefer_socket` connection url parameter.
773    pub fn prefer_socket(mut self, prefer_socket: bool) -> Self {
774        self.opts.0.prefer_socket = prefer_socket;
775        self
776    }
777
778    /// Commands to execute on each new database connection.
779    pub fn init<T: Into<String>>(mut self, init: Vec<T>) -> Self {
780        self.opts.0.init = init.into_iter().map(Into::into).collect();
781        self
782    }
783
784    /// Driver will require SSL connection if this option isn't `None` (default to `None`).
785    pub fn ssl_opts<T: Into<Option<SslOpts>>>(mut self, ssl_opts: T) -> Self {
786        self.opts.0.ssl_opts = ssl_opts.into();
787        self
788    }
789
790    /// Callback to handle requests for local files. These are
791    /// caused by using `LOAD DATA LOCAL INFILE` queries. The
792    /// callback is passed the filename, and a `Write`able object
793    /// to receive the contents of that file.
794    /// If unset, the default callback will read files relative to
795    /// the current directory.
796    pub fn local_infile_handler(mut self, handler: Option<LocalInfileHandler>) -> Self {
797        self.opts.0.local_infile_handler = handler;
798        self
799    }
800
801    /// Tcp connect timeout (defaults to `None`). Available as `tcp_connect_timeout_ms`
802    /// url parameter.
803    ///
804    /// Can be defined using `tcp_connect_timeout_ms` connection url parameter.
805    pub fn tcp_connect_timeout(mut self, timeout: Option<Duration>) -> Self {
806        self.opts.0.tcp_connect_timeout = timeout;
807        self
808    }
809
810    /// Bind address for a client (defaults to `None`).
811    ///
812    /// Use carefully. Will probably make pool unusable because of *address already in use*
813    /// errors.
814    pub fn bind_address<T>(mut self, bind_address: Option<T>) -> Self
815    where
816        T: Into<SocketAddr>,
817    {
818        self.opts.0.bind_address = bind_address.map(Into::into);
819        self
820    }
821
822    /// Number of prepared statements cached on the client side (per connection).
823    /// Defaults to [`DEFAULT_STMT_CACHE_SIZE`].
824    ///
825    /// Can be defined using `stmt_cache_size` connection url parameter.
826    ///
827    /// Call with `None` to reset to default.
828    pub fn stmt_cache_size<T>(mut self, cache_size: T) -> Self
829    where
830        T: Into<Option<usize>>,
831    {
832        self.opts.0.stmt_cache_size = cache_size.into().unwrap_or(128);
833        self
834    }
835
836    /// If not `None`, then client will ask for compression if server supports it
837    /// (defaults to `None`).
838    ///
839    /// Can be defined using `compress` connection url parameter with values:
840    /// * `true` - library defined default compression level;
841    /// * `fast` - library defined fast compression level;
842    /// * `best` - library defined best compression level;
843    /// * `0`, `1`, ..., `9` - explicitly defined compression level where `0` stands for
844    ///   "no compression";
845    ///
846    /// Note that compression level defined here will affect only outgoing packets.
847    pub fn compress(mut self, compress: Option<crate::Compression>) -> Self {
848        self.opts.0.compress = compress;
849        self
850    }
851
852    /// Additional client capabilities to set (defaults to empty).
853    ///
854    /// This value will be OR'ed with other client capabilities during connection initialisation.
855    ///
856    /// ### Note
857    ///
858    /// It is a good way to set something like `CLIENT_FOUND_ROWS` but you should note that it
859    /// won't let you to interfere with capabilities managed by other options (like
860    /// `CLIENT_SSL` or `CLIENT_COMPRESS`). Also note that some capabilities are reserved,
861    /// pointless or may broke the connection, so this option should be used with caution.
862    pub fn additional_capabilities(mut self, additional_capabilities: CapabilityFlags) -> Self {
863        let forbidden_flags: CapabilityFlags = CapabilityFlags::CLIENT_PROTOCOL_41
864            | CapabilityFlags::CLIENT_SSL
865            | CapabilityFlags::CLIENT_COMPRESS
866            | CapabilityFlags::CLIENT_SECURE_CONNECTION
867            | CapabilityFlags::CLIENT_LONG_PASSWORD
868            | CapabilityFlags::CLIENT_TRANSACTIONS
869            | CapabilityFlags::CLIENT_LOCAL_FILES
870            | CapabilityFlags::CLIENT_MULTI_STATEMENTS
871            | CapabilityFlags::CLIENT_MULTI_RESULTS
872            | CapabilityFlags::CLIENT_PS_MULTI_RESULTS;
873
874        self.opts.0.additional_capabilities = additional_capabilities & !forbidden_flags;
875        self
876    }
877
878    /// Connect attributes
879    ///
880    /// This value is sent to the server as custom name-value attributes.
881    /// You can see them from performance_schema tables: [`session_account_connect_attrs`
882    /// and `session_connect_attrs`][attr_tables] when all of the following conditions
883    /// are met.
884    ///
885    /// * The server is MySQL 5.6 or later, or MariaDB 10.0 or later.
886    /// * [`performance_schema`] is on.
887    /// * [`performance_schema_session_connect_attrs_size`] is -1 or big enough
888    ///   to store specified attributes.
889    ///
890    /// ### Note
891    ///
892    /// Attribute names that begin with an underscore (`_`) are not set by
893    /// application programs because they are reserved for internal use.
894    ///
895    /// The following attributes are sent in addition to ones set by programs.
896    ///
897    /// name            | value
898    /// ----------------|--------------------------
899    /// _client_name    | The client library name (`rust-mysql-simple`)
900    /// _client_version | The client library version
901    /// _os             | The operation system (`target_os` cfg feature)
902    /// _pid            | The client process ID
903    /// _platform       | The machine platform (`target_arch` cfg feature)
904    /// program_name    | The first element of `std::env::args` if program_name isn't set by programs.
905    ///
906    /// [attr_tables]: https://dev.mysql.com/doc/refman/en/performance-schema-connection-attribute-tables.html
907    /// [`performance_schema`]: https://dev.mysql.com/doc/refman/8.0/en/performance-schema-system-variables.html#sysvar_performance_schema
908    /// [`performance_schema_session_connect_attrs_size`]: https://dev.mysql.com/doc/refman/en/performance-schema-system-variables.html#sysvar_performance_schema_session_connect_attrs_size
909    ///
910    pub fn connect_attrs<T1: Into<String> + Eq + Hash, T2: Into<String>>(
911        mut self,
912        connect_attrs: HashMap<T1, T2>,
913    ) -> Self {
914        self.opts.0.connect_attrs = HashMap::with_capacity(connect_attrs.len());
915        for (name, value) in connect_attrs {
916            let name = name.into();
917            if !name.starts_with('_') {
918                self.opts.0.connect_attrs.insert(name, value.into());
919            }
920        }
921        self
922    }
923
924    /// Disables `mysql_old_password` plugin (defaults to `true`).
925    ///
926    /// Available via `secure_auth` connection url parameter.
927    pub fn secure_auth(mut self, secure_auth: bool) -> Self {
928        self.opts.0.secure_auth = secure_auth;
929        self
930    }
931}
932
933impl From<OptsBuilder> for Opts {
934    fn from(builder: OptsBuilder) -> Opts {
935        builder.opts
936    }
937}
938
939impl Default for OptsBuilder {
940    fn default() -> OptsBuilder {
941        OptsBuilder {
942            opts: Opts::default(),
943        }
944    }
945}
946
947fn get_opts_user_from_url(url: &Url) -> Option<String> {
948    let user = url.username();
949    if user != "" {
950        Some(
951            percent_decode(user.as_ref())
952                .decode_utf8_lossy()
953                .into_owned(),
954        )
955    } else {
956        None
957    }
958}
959
960fn get_opts_pass_from_url(url: &Url) -> Option<String> {
961    if let Some(pass) = url.password() {
962        Some(
963            percent_decode(pass.as_ref())
964                .decode_utf8_lossy()
965                .into_owned(),
966        )
967    } else {
968        None
969    }
970}
971
972fn get_opts_db_name_from_url(url: &Url) -> Option<String> {
973    if let Some(mut segments) = url.path_segments() {
974        segments
975            .next()
976            .filter(|&db_name| !db_name.is_empty())
977            .map(|db_name| {
978                percent_decode(db_name.as_ref())
979                    .decode_utf8_lossy()
980                    .into_owned()
981            })
982    } else {
983        None
984    }
985}
986
987fn from_url_basic(url_str: &str) -> Result<(Opts, Vec<(String, String)>), UrlError> {
988    let url = Url::parse(url_str)?;
989    if url.scheme() != "mysql" {
990        return Err(UrlError::UnsupportedScheme(url.scheme().to_string()));
991    }
992    if url.cannot_be_a_base() {
993        return Err(UrlError::BadUrl);
994    }
995    let user = get_opts_user_from_url(&url);
996    let pass = get_opts_pass_from_url(&url);
997    let ip_or_hostname = url
998        .host()
999        .ok_or(UrlError::BadUrl)
1000        .and_then(|host| url::Host::parse(&host.to_string()).map_err(|_| UrlError::BadUrl))?;
1001    let tcp_port = url.port().unwrap_or(3306);
1002    let db_name = get_opts_db_name_from_url(&url);
1003
1004    let query_pairs = url.query_pairs().into_owned().collect();
1005    let opts = Opts(Box::new(InnerOpts {
1006        user,
1007        pass,
1008        ip_or_hostname,
1009        tcp_port,
1010        db_name,
1011        ..InnerOpts::default()
1012    }));
1013
1014    Ok((opts, query_pairs))
1015}
1016
1017fn from_url(url: &str) -> Result<Opts, UrlError> {
1018    let (opts, query_pairs) = from_url_basic(url)?;
1019    let hash_map = query_pairs.into_iter().collect::<HashMap<String, String>>();
1020    OptsBuilder::from_opts(opts)
1021        .from_hash_map(&hash_map)
1022        .map(Into::into)
1023}
1024
#[cfg(test)]
mod test {
    use mysql_common::proto::codec::Compression;
    use std::time::Duration;

    use super::{InnerOpts, Opts, OptsBuilder};

    /// Builds a `std::collections::HashMap` from a list of `key => value` pairs.
    ///
    /// Shared by the tests that feed settings to `OptsBuilder::from_hash_map`.
    macro_rules! map {
        ($($key:expr => $value:expr),+) => {{
            let mut h = std::collections::HashMap::new();
            $(
                h.insert($key, $value);
            )+
            h
        }};
    }

    /// Opens a `Conn` and a `Pool` from each of the given url, opts and builder,
    /// panicking if any of them fails.
    // No test in this module calls this helper, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn assert_conn_from_url_opts_optsbuilder(url: &str, opts: Opts, opts_builder: OptsBuilder) {
        crate::Conn::new(url).unwrap();
        crate::Conn::new(opts.clone()).unwrap();
        crate::Conn::new(opts_builder.clone()).unwrap();
        crate::Pool::new(url).unwrap();
        crate::Pool::new(opts).unwrap();
        crate::Pool::new(opts_builder).unwrap();
    }

    #[test]
    fn should_report_empty_url_database_as_none() {
        let opt = Opts::from_url("mysql://localhost/").unwrap();
        assert_eq!(opt.get_db_name(), None);
    }

    #[test]
    fn should_convert_url_into_opts() {
        // Platform specific url parameters are only appended where they are supported.
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        let tcp_keepalive_probe_interval_secs = "&tcp_keepalive_probe_interval_secs=8";
        #[cfg(not(any(target_os = "linux", target_os = "macos",)))]
        let tcp_keepalive_probe_interval_secs = "";

        #[cfg(any(target_os = "linux", target_os = "macos",))]
        let tcp_keepalive_probe_count = "&tcp_keepalive_probe_count=5";
        #[cfg(not(any(target_os = "linux", target_os = "macos",)))]
        let tcp_keepalive_probe_count = "";

        #[cfg(target_os = "linux")]
        let tcp_user_timeout = "&tcp_user_timeout_ms=6000";
        #[cfg(not(target_os = "linux"))]
        let tcp_user_timeout = "";

        let opts = format!(
            "mysql://us%20r:p%20w@localhost:3308/db%2dname?prefer_socket=false&tcp_keepalive_time_ms=5000{}{}{}&socket=%2Ftmp%2Fmysql.sock&compress=8",
            tcp_keepalive_probe_interval_secs,
            tcp_keepalive_probe_count,
            tcp_user_timeout,
        );
        assert_eq!(
            Opts(Box::new(InnerOpts {
                user: Some("us r".to_string()),
                pass: Some("p w".to_string()),
                ip_or_hostname: url::Host::Domain("localhost".to_string()),
                tcp_port: 3308,
                db_name: Some("db-name".to_string()),
                prefer_socket: false,
                tcp_keepalive_time: Some(5000),
                #[cfg(any(target_os = "linux", target_os = "macos",))]
                tcp_keepalive_probe_interval_secs: Some(8),
                #[cfg(any(target_os = "linux", target_os = "macos",))]
                tcp_keepalive_probe_count: Some(5),
                #[cfg(target_os = "linux")]
                tcp_user_timeout: Some(6000),
                socket: Some("/tmp/mysql.sock".into()),
                compress: Some(Compression::new(8)),
                ..InnerOpts::default()
            })),
            Opts::from_url(&opts).unwrap(),
        );
    }

    #[test]
    #[should_panic]
    fn should_panic_on_invalid_url() {
        let opts = "42";
        Opts::from_url(opts).unwrap();
    }

    #[test]
    #[should_panic]
    fn should_panic_on_invalid_scheme() {
        let opts = "postgres://localhost";
        Opts::from_url(opts).unwrap();
    }

    #[test]
    #[should_panic]
    fn should_panic_on_unknown_query_param() {
        let opts = "mysql://localhost/foo?bar=baz";
        Opts::from_url(opts).unwrap();
    }

    #[test]
    fn should_read_hashmap_into_opts() {
        let mut cnf_map = map! {
            "user".to_string() => "test".to_string(),
            "password".to_string() => "password".to_string(),
            "host".to_string() => "127.0.0.1".to_string(),
            "port".to_string() => "8080".to_string(),
            "db_name".to_string() => "test_db".to_string(),
            "prefer_socket".to_string() => "false".to_string(),
            "tcp_keepalive_time_ms".to_string() => "5000".to_string(),
            "compress".to_string() => "best".to_string(),
            "tcp_connect_timeout_ms".to_string() => "1000".to_string(),
            "stmt_cache_size".to_string() => "33".to_string()
        };
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        cnf_map.insert(
            "tcp_keepalive_probe_interval_secs".to_string(),
            "8".to_string(),
        );
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        cnf_map.insert("tcp_keepalive_probe_count".to_string(), "5".to_string());

        let parsed_opts = OptsBuilder::new().from_hash_map(&cnf_map).unwrap();

        assert_eq!(parsed_opts.opts.get_user(), Some("test"));
        assert_eq!(parsed_opts.opts.get_pass(), Some("password"));
        assert_eq!(parsed_opts.opts.get_ip_or_hostname(), "127.0.0.1");
        assert_eq!(parsed_opts.opts.get_tcp_port(), 8080);
        assert_eq!(parsed_opts.opts.get_db_name(), Some("test_db"));
        assert!(!parsed_opts.opts.get_prefer_socket());
        assert_eq!(parsed_opts.opts.get_tcp_keepalive_time_ms(), Some(5000));
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        assert_eq!(
            parsed_opts.opts.get_tcp_keepalive_probe_interval_secs(),
            Some(8)
        );
        #[cfg(any(target_os = "linux", target_os = "macos",))]
        assert_eq!(parsed_opts.opts.get_tcp_keepalive_probe_count(), Some(5));
        assert_eq!(
            parsed_opts.opts.get_compress(),
            Some(crate::Compression::best())
        );
        assert_eq!(
            parsed_opts.opts.get_tcp_connect_timeout(),
            Some(Duration::from_millis(1000))
        );
        assert_eq!(parsed_opts.opts.get_stmt_cache_size(), 33);
    }

    #[test]
    fn should_have_url_err() {
        use crate::UrlError;

        // Same settings as in `should_read_hashmap_into_opts`, except for the port value,
        // which isn't a number.
        let cnf_map = map! {
            "user".to_string() => "test".to_string(),
            "password".to_string() => "password".to_string(),
            "host".to_string() => "127.0.0.1".to_string(),
            "port".to_string() => "NOTAPORT".to_string(),
            "db_name".to_string() => "test_db".to_string(),
            "prefer_socket".to_string() => "false".to_string(),
            "tcp_keepalive_time_ms".to_string() => "5000".to_string(),
            "compress".to_string() => "best".to_string(),
            "tcp_connect_timeout_ms".to_string() => "1000".to_string(),
            "stmt_cache_size".to_string() => "33".to_string()
        };

        let parsed = OptsBuilder::new().from_hash_map(&cnf_map);
        assert_eq!(
            parsed,
            Err(UrlError::InvalidValue(
                "port".to_string(),
                "NOTAPORT".to_string()
            ))
        );
    }
}