use crate::consts::CapabilityFlags;
use std::collections::HashMap;
use std::hash::Hash;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
#[cfg(all(feature = "ssl", not(target_os = "windows")))]
use std::path;
use std::str::FromStr;
use std::time::Duration;
use super::super::error::UrlError;
use super::LocalInfileHandler;
use url::percent_encoding::percent_decode;
use url::Url;
// TLS settings stored in `Opts::ssl_opts`. The concrete shape depends on the target
// platform and on whether the `ssl` feature is enabled; `OptsBuilder::ssl_opts` has a
// matching cfg-specific signature for each variant.
//
// NOTE(review): with `ssl` enabled on a target that is neither unix nor windows (e.g.
// wasm) none of these aliases is defined — confirm such targets are unsupported.
//
// macOS: `None` = TLS not configured; `Some(None)` / `Some(Some((pkcs12_path, pass,
// certs)))` — a PKCS#12 archive path, its password, and additional certificate paths
// (names taken from `OptsBuilder::ssl_opts`).
#[cfg(all(feature = "ssl", target_os = "macos"))]
pub type SslOpts = Option<Option<(path::PathBuf, String, Vec<path::PathBuf>)>>;
// Other unix targets: `Some((ca_cert, client))` where `client` is an optional
// `(client_cert, client_key)` pair of paths (names taken from `OptsBuilder::ssl_opts`).
#[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
pub type SslOpts = Option<(path::PathBuf, Option<(path::PathBuf, path::PathBuf)>)>;
// Windows: no TLS configuration is supported here; `OptsBuilder::ssl_opts` panics.
#[cfg(all(feature = "ssl", target_os = "windows"))]
pub type SslOpts = Option<()>;
// Without the `ssl` feature there is nothing to configure; `OptsBuilder::ssl_opts` panics.
#[cfg(not(feature = "ssl"))]
pub type SslOpts = Option<()>;
/// Connection options.
///
/// Build with `OptsBuilder`, or parse from a `mysql://` URL via `Opts::from_url`
/// (fallible) or `From<&str>` (panics on an invalid URL). Defaults come from
/// `Opts::default()`.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct Opts {
/// Server address or host name; defaults to `"127.0.0.1"`.
ip_or_hostname: Option<String>,
/// Server TCP port; defaults to `3306`.
tcp_port: u16,
/// Socket to connect through, if any.
/// NOTE(review): presumably a unix socket path / named pipe name — confirm with the
/// connection code.
socket: Option<String>,
/// User name (percent-decoded when taken from a URL).
user: Option<String>,
/// Password (percent-decoded when taken from a URL).
pass: Option<String>,
/// Database name (percent-decoded when taken from a URL).
db_name: Option<String>,
/// Read timeout; `None` means none is configured.
read_timeout: Option<Duration>,
/// Write timeout; `None` means none is configured.
write_timeout: Option<Duration>,
/// Whether a socket is preferred; defaults to `true`.
/// NOTE(review): presumably combined with `addr_is_loopback` by the connection code —
/// confirm.
prefer_socket: bool,
/// Whether `TCP_NODELAY` is requested; defaults to `true`.
tcp_nodelay: bool,
/// TCP keepalive time in milliseconds (see `get_tcp_keepalive_time_ms` and the
/// `tcp_keepalive_time_ms` URL parameter).
tcp_keepalive_time: Option<u32>,
/// Init strings. NOTE(review): presumably queries run after connecting — confirm.
init: Vec<String>,
/// Whether the TLS peer is verified; the URL parameter requires the `ssl` feature.
verify_peer: bool,
/// Platform-dependent TLS settings (see `SslOpts`).
ssl_opts: SslOpts,
/// Handler used for `LOCAL INFILE` requests, if any.
local_infile_handler: Option<LocalInfileHandler>,
/// TCP connect timeout (URL parameter `tcp_connect_timeout_ms`, in milliseconds).
tcp_connect_timeout: Option<Duration>,
/// Local address to bind the client socket to, if any.
bind_address: Option<SocketAddr>,
/// Statement cache capacity; defaults to `10`.
stmt_cache_size: usize,
/// Whether compression is requested; defaults to `false`.
compress: bool,
/// Extra capability flags; `OptsBuilder::additional_capabilities` masks out the flags
/// the client manages itself.
additional_capabilities: CapabilityFlags,
/// Connection attributes; `OptsBuilder::connect_attrs` drops names starting with `_`.
connect_attrs: HashMap<String, String>,
/// Test-only override. NOTE(review): purpose not visible in this file — confirm.
#[cfg(test)]
pub injected_socket: Option<String>,
}
impl Opts {
#[doc(hidden)]
pub fn addr_is_loopback(&self) -> bool {
if self.ip_or_hostname.is_some() {
let v4addr: Option<Ipv4Addr> =
FromStr::from_str(self.ip_or_hostname.as_ref().unwrap().as_ref()).ok();
let v6addr: Option<Ipv6Addr> =
FromStr::from_str(self.ip_or_hostname.as_ref().unwrap().as_ref()).ok();
if let Some(addr) = v4addr {
addr.is_loopback()
} else if let Some(addr) = v6addr {
addr.is_loopback()
} else if self.ip_or_hostname.as_ref().unwrap() == "localhost" {
true
} else {
false
}
} else {
false
}
}
pub fn from_url(url: &str) -> Result<Opts, UrlError> {
from_url(url)
}
pub fn get_ip_or_hostname(&self) -> Option<&str> {
self.ip_or_hostname.as_ref().map(|x| &**x)
}
pub fn get_tcp_port(&self) -> u16 {
self.tcp_port
}
pub fn get_socket(&self) -> Option<&str> {
self.socket.as_ref().map(|x| &**x)
}
pub fn get_user(&self) -> Option<&str> {
self.user.as_ref().map(|x| &**x)
}
pub fn get_pass(&self) -> Option<&str> {
self.pass.as_ref().map(|x| &**x)
}
pub fn get_db_name(&self) -> Option<&str> {
self.db_name.as_ref().map(|x| &**x)
}
pub fn get_read_timeout(&self) -> Option<&Duration> {
self.read_timeout.as_ref()
}
pub fn get_write_timeout(&self) -> Option<&Duration> {
self.write_timeout.as_ref()
}
pub fn get_prefer_socket(&self) -> bool {
self.prefer_socket
}
pub fn get_init(&self) -> Vec<String> {
self.init.clone()
}
pub fn get_verify_peer(&self) -> bool {
self.verify_peer
}
pub fn get_ssl_opts(&self) -> &SslOpts {
&self.ssl_opts
}
fn set_prefer_socket(&mut self, val: bool) {
self.prefer_socket = val;
}
fn set_verify_peer(&mut self, val: bool) {
self.verify_peer = val;
}
pub fn get_tcp_nodelay(&self) -> bool {
self.tcp_nodelay
}
pub fn get_tcp_keepalive_time_ms(&self) -> Option<u32> {
self.tcp_keepalive_time
}
pub fn get_local_infile_handler(&self) -> Option<&LocalInfileHandler> {
self.local_infile_handler.as_ref()
}
pub fn get_tcp_connect_timeout(&self) -> Option<Duration> {
self.tcp_connect_timeout
}
pub fn bind_address(&self) -> Option<&SocketAddr> {
self.bind_address.as_ref()
}
pub fn get_stmt_cache_size(&self) -> usize {
self.stmt_cache_size
}
pub fn get_compress(&self) -> bool {
self.compress
}
pub fn get_additional_capabilities(&self) -> CapabilityFlags {
self.additional_capabilities
}
pub fn get_connect_attrs(&self) -> &HashMap<String, String> {
&self.connect_attrs
}
}
impl Default for Opts {
    /// Baseline options: TCP to `127.0.0.1:3306`, no credentials, no database, socket
    /// preferred, `TCP_NODELAY` on, peer verification and compression off, and a
    /// statement cache of 10 entries.
    fn default() -> Self {
        Opts {
            ip_or_hostname: Some(String::from("127.0.0.1")),
            tcp_port: 3306,
            socket: None,
            user: None,
            pass: None,
            db_name: None,
            read_timeout: None,
            write_timeout: None,
            prefer_socket: true,
            tcp_nodelay: true,
            tcp_keepalive_time: None,
            init: Vec::new(),
            verify_peer: false,
            ssl_opts: None,
            local_infile_handler: None,
            tcp_connect_timeout: None,
            bind_address: None,
            stmt_cache_size: 10,
            compress: false,
            additional_capabilities: CapabilityFlags::empty(),
            connect_attrs: HashMap::new(),
            #[cfg(test)]
            injected_socket: None,
        }
    }
}
/// Builder for `Opts`.
///
/// Start with `OptsBuilder::new()` (defaults) or `OptsBuilder::from_opts(..)`, chain the
/// `&mut self` setters, then convert the builder into `Opts` via `From`/`Into`.
#[derive(Clone, Debug)]
pub struct OptsBuilder {
    opts: Opts,
}
impl OptsBuilder {
    /// Creates a builder initialised with `Opts::default()`.
    pub fn new() -> Self {
        OptsBuilder::default()
    }

    /// Creates a builder seeded from anything convertible into `Opts`.
    ///
    /// # Panics
    ///
    /// Passing a URL string goes through `From<S: AsRef<str>> for Opts`, which panics if
    /// the URL is invalid.
    pub fn from_opts<T: Into<Opts>>(opts: T) -> Self {
        OptsBuilder { opts: opts.into() }
    }

    /// Sets the server address or host name.
    pub fn ip_or_hostname<T: Into<String>>(&mut self, ip_or_hostname: Option<T>) -> &mut Self {
        self.opts.ip_or_hostname = ip_or_hostname.map(Into::into);
        self
    }

    /// Sets the server TCP port.
    pub fn tcp_port(&mut self, tcp_port: u16) -> &mut Self {
        self.opts.tcp_port = tcp_port;
        self
    }

    /// Sets the socket to connect through.
    pub fn socket<T: Into<String>>(&mut self, socket: Option<T>) -> &mut Self {
        self.opts.socket = socket.map(Into::into);
        self
    }

    /// Sets the user name.
    pub fn user<T: Into<String>>(&mut self, user: Option<T>) -> &mut Self {
        self.opts.user = user.map(Into::into);
        self
    }

    /// Sets the password.
    pub fn pass<T: Into<String>>(&mut self, pass: Option<T>) -> &mut Self {
        self.opts.pass = pass.map(Into::into);
        self
    }

    /// Sets the database name.
    pub fn db_name<T: Into<String>>(&mut self, db_name: Option<T>) -> &mut Self {
        self.opts.db_name = db_name.map(Into::into);
        self
    }

    /// Sets the read timeout.
    pub fn read_timeout(&mut self, read_timeout: Option<Duration>) -> &mut Self {
        self.opts.read_timeout = read_timeout;
        self
    }

    /// Sets the write timeout.
    pub fn write_timeout(&mut self, write_timeout: Option<Duration>) -> &mut Self {
        self.opts.write_timeout = write_timeout;
        self
    }

    /// Sets the TCP keepalive time in milliseconds.
    pub fn tcp_keepalive_time_ms(&mut self, tcp_keepalive_time_ms: Option<u32>) -> &mut Self {
        self.opts.tcp_keepalive_time = tcp_keepalive_time_ms;
        self
    }

    /// Sets whether `TCP_NODELAY` is requested.
    pub fn tcp_nodelay(&mut self, nodelay: bool) -> &mut Self {
        self.opts.tcp_nodelay = nodelay;
        self
    }

    /// Sets whether a socket connection is preferred.
    pub fn prefer_socket(&mut self, prefer_socket: bool) -> &mut Self {
        self.opts.prefer_socket = prefer_socket;
        self
    }

    /// Replaces the init strings.
    pub fn init<T: Into<String>>(&mut self, init: Vec<T>) -> &mut Self {
        self.opts.init = init.into_iter().map(Into::into).collect();
        self
    }

    /// Sets whether the TLS peer is verified.
    pub fn verify_peer(&mut self, verify_peer: bool) -> &mut Self {
        self.opts.verify_peer = verify_peer;
        self
    }

    /// Sets TLS options: a CA certificate path plus an optional
    /// `(client_cert, client_key)` pair of paths.
    #[cfg(all(feature = "ssl", not(target_os = "macos"), unix))]
    pub fn ssl_opts<A, B, C>(&mut self, ssl_opts: Option<(A, Option<(B, C)>)>) -> &mut Self
    where
        A: Into<path::PathBuf>,
        B: Into<path::PathBuf>,
        C: Into<path::PathBuf>,
    {
        self.opts.ssl_opts = ssl_opts.map(|(ca_cert, rest)| {
            (
                ca_cert.into(),
                rest.map(|(client_cert, client_key)| (client_cert.into(), client_key.into())),
            )
        });
        self
    }

    /// Sets TLS options: an optional `(pkcs12_path, password, extra_cert_paths)` triple.
    #[cfg(all(feature = "ssl", target_os = "macos"))]
    pub fn ssl_opts<A, B, C>(&mut self, ssl_opts: Option<Option<(A, C, Vec<B>)>>) -> &mut Self
    where
        A: Into<path::PathBuf>,
        B: Into<path::PathBuf>,
        C: Into<String>,
    {
        self.opts.ssl_opts = ssl_opts.map(|opts| {
            opts.map(|(pkcs12_path, pass, certs)| {
                (
                    pkcs12_path.into(),
                    pass.into(),
                    certs.into_iter().map(Into::into).collect(),
                )
            })
        });
        self
    }

    /// Not supported on Windows.
    ///
    /// # Panics
    ///
    /// Always panics.
    #[cfg(all(feature = "ssl", target_os = "windows"))]
    pub fn ssl_opts<A, B, C>(&mut self, _: Option<SslOpts>) -> &mut Self {
        panic!("OptsBuilder::ssl_opts is not implemented on Windows");
    }

    /// Requires the `ssl` feature.
    ///
    /// # Panics
    ///
    /// Always panics when built without the `ssl` feature.
    #[cfg(not(feature = "ssl"))]
    pub fn ssl_opts<A, B, C>(&mut self, _: Option<SslOpts>) -> &mut Self {
        panic!("OptsBuilder::ssl_opts requires `ssl` feature");
    }

    /// Sets the handler for `LOCAL INFILE` requests.
    pub fn local_infile_handler(&mut self, handler: Option<LocalInfileHandler>) -> &mut Self {
        self.opts.local_infile_handler = handler;
        self
    }

    /// Sets the TCP connect timeout.
    pub fn tcp_connect_timeout(&mut self, timeout: Option<Duration>) -> &mut Self {
        self.opts.tcp_connect_timeout = timeout;
        self
    }

    /// Sets the local address to bind the client socket to.
    pub fn bind_address<T>(&mut self, bind_address: Option<T>) -> &mut Self
    where
        T: Into<SocketAddr>,
    {
        self.opts.bind_address = bind_address.map(Into::into);
        self
    }

    /// Sets the statement cache capacity. Passing `None` restores the default of `10`.
    pub fn stmt_cache_size<T>(&mut self, cache_size: T) -> &mut Self
    where
        T: Into<Option<usize>>,
    {
        self.opts.stmt_cache_size = cache_size.into().unwrap_or(10);
        self
    }

    /// Sets whether compression is requested.
    pub fn compress(&mut self, compress: bool) -> &mut Self {
        self.opts.compress = compress;
        self
    }

    /// Sets extra capability flags.
    ///
    /// The flags listed below are silently removed from `additional_capabilities`.
    /// NOTE(review): presumably because the client manages them itself — confirm.
    pub fn additional_capabilities(
        &mut self,
        additional_capabilities: CapabilityFlags,
    ) -> &mut Self {
        let forbidden_flags: CapabilityFlags = CapabilityFlags::CLIENT_PROTOCOL_41
            | CapabilityFlags::CLIENT_SSL
            | CapabilityFlags::CLIENT_COMPRESS
            | CapabilityFlags::CLIENT_SECURE_CONNECTION
            | CapabilityFlags::CLIENT_LONG_PASSWORD
            | CapabilityFlags::CLIENT_TRANSACTIONS
            | CapabilityFlags::CLIENT_LOCAL_FILES
            | CapabilityFlags::CLIENT_MULTI_STATEMENTS
            | CapabilityFlags::CLIENT_MULTI_RESULTS
            | CapabilityFlags::CLIENT_PS_MULTI_RESULTS;
        self.opts.additional_capabilities = additional_capabilities & !forbidden_flags;
        self
    }

    /// Replaces the connection attributes.
    ///
    /// Any attributes set earlier are discarded. Names starting with `_` are silently
    /// skipped (NOTE(review): presumably reserved for client-provided attributes).
    pub fn connect_attrs<T1: Into<String> + Eq + Hash, T2: Into<String>>(
        &mut self,
        connect_attrs: HashMap<T1, T2>,
    ) -> &mut Self {
        let mut attrs: HashMap<String, String> = HashMap::with_capacity(connect_attrs.len());
        for (name, value) in connect_attrs {
            let name: String = name.into();
            if !name.starts_with('_') {
                attrs.insert(name, value.into());
            }
        }
        self.opts.connect_attrs = attrs;
        self
    }
}
impl From<OptsBuilder> for Opts {
fn from(builder: OptsBuilder) -> Opts {
builder.opts
}
}
impl Default for OptsBuilder {
    /// Returns a builder seeded with `Opts::default()`.
    fn default() -> Self {
        let opts = Opts::default();
        OptsBuilder { opts }
    }
}
/// Extracts the percent-decoded user name from `url`.
///
/// Returns `None` when the URL has no user name. Invalid UTF-8 in the decoded bytes is
/// replaced lossily rather than rejected.
fn get_opts_user_from_url(url: &Url) -> Option<String> {
    let raw = url.username();
    if raw.is_empty() {
        return None;
    }
    let decoded = percent_decode(raw.as_bytes()).decode_utf8_lossy();
    Some(decoded.into_owned())
}
/// Extracts the percent-decoded password from `url`, or `None` if the URL has none.
///
/// Invalid UTF-8 in the decoded bytes is replaced lossily rather than rejected.
fn get_opts_pass_from_url(url: &Url) -> Option<String> {
    url.password()
        .map(|raw| percent_decode(raw.as_bytes()).decode_utf8_lossy().into_owned())
}
/// Extracts the percent-decoded database name (the first path segment) from `url`.
///
/// Returns `None` when the URL has no path segments or when the first segment is empty.
/// An empty segment comes from a URL such as `mysql://host/`, and treating it as "no
/// database" keeps it consistent with `mysql://host`. Invalid UTF-8 in the decoded bytes
/// is replaced lossily.
fn get_opts_db_name_from_url(url: &Url) -> Option<String> {
    url.path_segments()?
        .next()
        .filter(|db_name| !db_name.is_empty())
        .map(|db_name| {
            percent_decode(db_name.as_bytes())
                .decode_utf8_lossy()
                .into_owned()
        })
}
fn from_url_basic(url_str: &str) -> Result<(Opts, Vec<(String, String)>), UrlError> {
let url = Url::parse(url_str)?;
if url.scheme() != "mysql" {
return Err(UrlError::UnsupportedScheme(url.scheme().to_string()));
}
if url.cannot_be_a_base() || !url.has_host() {
return Err(UrlError::BadUrl);
}
let user = get_opts_user_from_url(&url);
let pass = get_opts_pass_from_url(&url);
let ip_or_hostname = url.host_str().map(String::from);
let tcp_port = url.port().unwrap_or(3306);
let db_name = get_opts_db_name_from_url(&url);
let query_pairs = url.query_pairs().into_owned().collect();
let opts = Opts {
user,
pass,
ip_or_hostname,
tcp_port,
db_name,
..Opts::default()
};
Ok((opts, query_pairs))
}
/// Parses a `mysql://` URL into `Opts` and applies its query-string parameters.
///
/// Supported parameters are applied left to right, so a later occurrence overrides an
/// earlier one:
/// - `prefer_socket`, `verify_peer`, `compress`: each `true` or `false`.
/// - `tcp_keepalive_time_ms` (`u32`) and `tcp_connect_timeout_ms` (`u64`): milliseconds.
/// - `stmt_cache_size` (`usize`).
///
/// # Errors
///
/// - Anything returned by `from_url_basic`.
/// - `UrlError::InvalidValue` for an unparsable value.
/// - `UrlError::FeatureRequired` for `verify_peer` when the `ssl` feature is disabled.
/// - `UrlError::UnknownParameter` for any other key.
fn from_url(url: &str) -> Result<Opts, UrlError> {
    // Accepts exactly `true` or `false`; any other value is reported against `param`.
    fn parse_flag(param: &str, value: String) -> Result<bool, UrlError> {
        if value == "true" {
            Ok(true)
        } else if value == "false" {
            Ok(false)
        } else {
            Err(UrlError::InvalidValue(param.into(), value))
        }
    }

    // Parses a numeric value via `FromStr`; a failure is reported against `param`.
    fn parse_number<T: FromStr>(param: &str, value: String) -> Result<T, UrlError> {
        match value.parse::<T>() {
            Ok(number) => Ok(number),
            Err(_) => Err(UrlError::InvalidValue(param.into(), value)),
        }
    }

    let (mut opts, query_pairs) = from_url_basic(url)?;
    for (key, value) in query_pairs {
        match key.as_str() {
            "prefer_socket" => opts.set_prefer_socket(parse_flag("prefer_socket", value)?),
            "verify_peer" => {
                // The feature check comes first, so even a valid value is rejected
                // without `ssl`.
                if cfg!(not(feature = "ssl")) {
                    return Err(UrlError::FeatureRequired(
                        "`ssl'".into(),
                        "verify_peer".into(),
                    ));
                }
                opts.set_verify_peer(parse_flag("verify_peer", value)?);
            }
            "tcp_keepalive_time_ms" => {
                let millis = parse_number::<u32>("tcp_keepalive_time_ms", value)?;
                opts.tcp_keepalive_time = Some(millis);
            }
            "tcp_connect_timeout_ms" => {
                let millis = parse_number::<u64>("tcp_connect_timeout_ms", value)?;
                opts.tcp_connect_timeout = Some(Duration::from_millis(millis));
            }
            "stmt_cache_size" => {
                opts.stmt_cache_size = parse_number::<usize>("stmt_cache_size", value)?;
            }
            "compress" => opts.compress = parse_flag("compress", value)?,
            unknown => return Err(UrlError::UnknownParameter(unknown.to_owned())),
        }
    }
    Ok(opts)
}
impl<S: AsRef<str>> From<S> for Opts {
    /// Parses a `mysql://` URL into `Opts`.
    ///
    /// # Panics
    ///
    /// Panics with the `UrlError`'s display message if the URL is invalid. Use
    /// `Opts::from_url` to handle the error instead.
    fn from(url: S) -> Opts {
        from_url(url.as_ref()).unwrap_or_else(|err| panic!("{}", err))
    }
}
#[cfg(test)]
mod test {
    use super::Opts;

    // Runs `url` through the panicking `From<&str>` conversion under test.
    fn parse(url: &str) -> Opts {
        url.into()
    }

    #[test]
    #[cfg(feature = "ssl")]
    fn should_convert_url_into_opts() {
        let url = "mysql://us%20r:p%20w@localhost:3308/db%2dname?prefer_socket=false&verify_peer=true&tcp_keepalive_time_ms=5000";
        let expected = Opts {
            user: Some("us r".to_string()),
            pass: Some("p w".to_string()),
            ip_or_hostname: Some("localhost".to_string()),
            tcp_port: 3308,
            db_name: Some("db-name".to_string()),
            prefer_socket: false,
            verify_peer: true,
            tcp_keepalive_time: Some(5000),
            ..Opts::default()
        };
        assert_eq!(expected, parse(url));
    }

    #[test]
    #[cfg(not(feature = "ssl"))]
    fn should_convert_url_into_opts() {
        let url = "mysql://usr:pw@192.168.1.1:3309/dbname";
        let expected = Opts {
            user: Some("usr".to_string()),
            pass: Some("pw".to_string()),
            ip_or_hostname: Some("192.168.1.1".to_string()),
            tcp_port: 3309,
            db_name: Some("dbname".to_string()),
            ..Opts::default()
        };
        assert_eq!(expected, parse(url));
    }

    #[test]
    #[should_panic]
    fn should_panic_on_invalid_url() {
        let _ = parse("42");
    }

    #[test]
    #[should_panic]
    fn should_panic_on_invalid_scheme() {
        let _ = parse("postgres://localhost");
    }

    #[test]
    #[should_panic]
    fn should_panic_on_unknown_query_param() {
        let _ = parse("mysql://localhost/foo?bar=baz");
    }

    #[test]
    #[should_panic]
    #[cfg(not(feature = "ssl"))]
    fn should_panic_if_verify_peer_query_param_requires_feature() {
        let _ = parse("mysql://usr:pw@localhost:3308/dbname?verify_peer=false");
    }

    #[test]
    #[should_panic]
    #[cfg(feature = "ssl")]
    fn should_panic_on_invalid_verify_peer_param_value() {
        let _ = parse("mysql://usr:pw@localhost:3308/dbname?verify_peer=invalid");
    }
}