use percent_encoding::percent_decode;
use url::Url;
use std::{
borrow::Cow, collections::HashMap, fmt, hash::Hash, net::SocketAddr, path::Path, time::Duration,
};
use crate::{
consts::CapabilityFlags, Compression, LocalInfileHandler, PoolConstraints, PoolOpts, UrlError,
};
/// Default value of the statement-cache-size option (`stmt_cache_size`); this is the value
/// `InnerOpts::default()` starts from.
pub const DEFAULT_STMT_CACHE_SIZE: usize = 32;
// TLS-backend-specific option types. Each backend module provides its own `ClientIdentity`;
// the one matching the enabled feature is re-exported below.
mod native_tls_opts;
mod rustls_opts;
// Pool-related options (presumably where `PoolOpts`/`PoolConstraints` live — confirm).
pub mod pool_opts;
// NOTE(review): enabling both `native-tls` and `rustls` would import `ClientIdentity` twice;
// presumably the features are meant to be mutually exclusive — confirm.
#[cfg(feature = "native-tls")]
pub use native_tls_opts::ClientIdentity;
#[cfg(feature = "rustls")]
pub use rustls_opts::ClientIdentity;
/// TLS-related connection settings, configured through the `with_*` builder methods.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
pub struct SslOpts {
/// Optional client identity; the field only exists when a TLS backend feature is enabled.
#[cfg(any(feature = "native-tls", feature = "rustls"))]
client_identity: Option<ClientIdentity>,
/// Optional path to a root certificate.
root_cert_path: Option<Cow<'static, Path>>,
/// Set by `with_danger_skip_domain_validation`; `false` by default.
skip_domain_validation: bool,
/// Set by `with_danger_accept_invalid_certs`; `false` by default.
accept_invalid_certs: bool,
}
impl SslOpts {
    /// Returns these options with the client identity replaced by `identity`
    /// (`None` removes it).
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(
            feature = "native-tls",
            feature = "rustls-tls",
            feature = "rustls-tls-ring"
        )))
    )]
    pub fn with_client_identity(self, identity: Option<ClientIdentity>) -> Self {
        Self {
            client_identity: identity,
            ..self
        }
    }

    /// Returns these options with the root certificate path replaced
    /// (`None` removes it). Accepts anything convertible into `Cow<'static, Path>`,
    /// e.g. `&'static Path` or `PathBuf`.
    pub fn with_root_cert_path<T: Into<Cow<'static, Path>>>(self, root_cert_path: Option<T>) -> Self {
        Self {
            root_cert_path: root_cert_path.map(|path| path.into()),
            ..self
        }
    }

    /// Returns these options with the `skip_domain_validation` flag set to `value`.
    pub fn with_danger_skip_domain_validation(self, value: bool) -> Self {
        Self {
            skip_domain_validation: value,
            ..self
        }
    }

    /// Returns these options with the `accept_invalid_certs` flag set to `value`.
    pub fn with_danger_accept_invalid_certs(self, value: bool) -> Self {
        Self {
            accept_invalid_certs: value,
            ..self
        }
    }

    /// Borrows the configured client identity, if any.
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    #[cfg_attr(
        docsrs,
        doc(cfg(any(
            feature = "native-tls",
            feature = "rustls-tls",
            feature = "rustls-tls-ring"
        )))
    )]
    pub fn client_identity(&self) -> Option<&ClientIdentity> {
        match &self.client_identity {
            Some(identity) => Some(identity),
            None => None,
        }
    }

    /// Borrows the configured root certificate path, if any.
    pub fn root_cert_path(&self) -> Option<&Path> {
        self.root_cert_path.as_deref()
    }

    /// Current value of the `skip_domain_validation` flag.
    pub fn skip_domain_validation(&self) -> bool {
        self.skip_domain_validation
    }

    /// Current value of the `accept_invalid_certs` flag.
    pub fn accept_invalid_certs(&self) -> bool {
        self.accept_invalid_certs
    }
}
/// Storage for every connection option. `Opts` holds it behind a `Box`; `OptsBuilder` and the
/// URL parser write to it.
#[derive(Debug, Clone, Eq, PartialEq)]
pub(crate) struct InnerOpts {
/// Server host (IP address or domain name); `localhost` by default.
ip_or_hostname: url::Host,
/// Server TCP port; 3306 by default.
tcp_port: u16,
/// Optional socket to connect through (e.g. `/tmp/mysql.sock` in the tests below).
socket: Option<String>,
/// Optional user name.
user: Option<String>,
/// Optional password.
pass: Option<String>,
/// Optional database name.
db_name: Option<String>,
/// Optional read timeout.
read_timeout: Option<Duration>,
/// Optional write timeout.
write_timeout: Option<Duration>,
/// Socket-preference flag; `true` by default.
prefer_socket: bool,
/// TCP_NODELAY flag; `true` by default.
tcp_nodelay: bool,
/// Keep-alive time; the public getter/setter expose it with an `_ms` (milliseconds) suffix.
tcp_keepalive_time: Option<u32>,
/// Keep-alive probe interval in seconds (Linux/macOS only).
#[cfg(any(target_os = "linux", target_os = "macos"))]
tcp_keepalive_probe_interval_secs: Option<u32>,
/// Keep-alive probe count (Linux/macOS only).
#[cfg(any(target_os = "linux", target_os = "macos"))]
tcp_keepalive_probe_count: Option<u32>,
/// TCP user timeout; exposed with an `_ms` suffix (Linux only).
#[cfg(target_os = "linux")]
tcp_user_timeout: Option<u32>,
/// Init commands set via `OptsBuilder::init` (presumably run after connecting — confirm).
init: Vec<String>,
/// TLS settings; `None` by default.
ssl_opts: Option<SslOpts>,
/// Pool options (constraints, reset/health-check flags).
pool_opts: PoolOpts,
/// Optional `LocalInfileHandler`.
local_infile_handler: Option<LocalInfileHandler>,
/// Optional TCP connect timeout.
tcp_connect_timeout: Option<Duration>,
/// Optional local bind address.
bind_address: Option<SocketAddr>,
/// Statement cache size; `DEFAULT_STMT_CACHE_SIZE` by default.
stmt_cache_size: usize,
/// Optional compression level.
compress: Option<crate::Compression>,
/// Extra capability flags; `OptsBuilder::additional_capabilities` strips a fixed set of flags.
additional_capabilities: CapabilityFlags,
/// Connection attributes; the builder drops names starting with `_`. `Some(empty)` by default.
connect_attrs: Option<HashMap<String, String>>,
/// Secure-auth flag; `true` by default.
secure_auth: bool,
/// Cleartext-plugin flag; `false` by default.
enable_cleartext_plugin: bool,
/// Optional max allowed packet size (the builder setter clamps it).
max_allowed_packet: Option<usize>,
/// Test-only field (compiled only under `cfg(test)`).
#[cfg(test)]
pub injected_socket: Option<String>,
}
impl Default for InnerOpts {
    /// Baseline options: `localhost:3306`, no credentials or database, socket preferred,
    /// TCP_NODELAY on, secure auth on, cleartext plugin off, an empty connect-attribute map
    /// and a statement cache of `DEFAULT_STMT_CACHE_SIZE` entries. Everything else is unset.
    fn default() -> Self {
        Self {
            ip_or_hostname: url::Host::Domain("localhost".to_owned()),
            tcp_port: 3306,
            socket: None,
            user: None,
            pass: None,
            db_name: None,
            read_timeout: None,
            write_timeout: None,
            prefer_socket: true,
            tcp_nodelay: true,
            tcp_keepalive_time: None,
            #[cfg(any(target_os = "linux", target_os = "macos"))]
            tcp_keepalive_probe_interval_secs: None,
            #[cfg(any(target_os = "linux", target_os = "macos"))]
            tcp_keepalive_probe_count: None,
            #[cfg(target_os = "linux")]
            tcp_user_timeout: None,
            init: Vec::new(),
            ssl_opts: None,
            pool_opts: PoolOpts::default(),
            local_infile_handler: None,
            tcp_connect_timeout: None,
            bind_address: None,
            stmt_cache_size: DEFAULT_STMT_CACHE_SIZE,
            compress: None,
            additional_capabilities: CapabilityFlags::empty(),
            connect_attrs: Some(HashMap::new()),
            secure_auth: true,
            enable_cleartext_plugin: false,
            max_allowed_packet: None,
            #[cfg(test)]
            injected_socket: None,
        }
    }
}
impl TryFrom<&'_ str> for Opts {
type Error = UrlError;
fn try_from(url: &'_ str) -> Result<Self, Self::Error> {
Opts::from_url(url)
}
}
/// Connection options. Wraps a boxed `InnerOpts`; build one with `OptsBuilder`, or parse a
/// `mysql://` URL with `Opts::from_url` / `TryFrom<&str>`.
#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct Opts(pub(crate) Box<InnerOpts>);
impl Opts {
    /// Returns `true` when the configured host is the domain name `localhost` or a loopback
    /// IPv4/IPv6 address.
    #[doc(hidden)]
    pub fn addr_is_loopback(&self) -> bool {
        match &self.0.ip_or_hostname {
            url::Host::Domain(name) => name == "localhost",
            url::Host::Ipv4(addr) => addr.is_loopback(),
            url::Host::Ipv6(addr) => addr.is_loopback(),
        }
    }

    /// Parses a `mysql://` connection URL into options.
    ///
    /// # Errors
    ///
    /// Returns a `UrlError` if the URL is malformed, uses a scheme other than `mysql`, or
    /// carries an unknown or unparsable query parameter.
    pub fn from_url(url: &str) -> Result<Opts, UrlError> {
        from_url(url)
    }

    /// Returns an owned copy of the configured host.
    pub(crate) fn get_host(&self) -> url::Host {
        self.0.ip_or_hostname.clone()
    }

    /// Returns the configured host as text.
    ///
    /// Domain names are borrowed. IP addresses are formatted into a new string.
    pub fn get_ip_or_hostname(&self) -> Cow<'_, str> {
        match &self.0.ip_or_hostname {
            // `Display` for `Host::Domain` writes the name unchanged, so borrowing it gives the
            // same text without an allocation.
            url::Host::Domain(name) => Cow::Borrowed(name.as_str()),
            other => Cow::Owned(other.to_string()),
        }
    }

    /// Server TCP port.
    pub fn get_tcp_port(&self) -> u16 {
        self.0.tcp_port
    }

    /// Socket to connect through, if configured.
    pub fn get_socket(&self) -> Option<&str> {
        self.0.socket.as_deref()
    }

    /// Configured max allowed packet size, if any.
    pub fn get_max_allowed_packet(&self) -> Option<usize> {
        self.0.max_allowed_packet
    }

    /// User name, if configured.
    pub fn get_user(&self) -> Option<&str> {
        self.0.user.as_deref()
    }

    /// Password, if configured.
    pub fn get_pass(&self) -> Option<&str> {
        self.0.pass.as_deref()
    }

    /// Database name, if configured.
    pub fn get_db_name(&self) -> Option<&str> {
        self.0.db_name.as_deref()
    }

    /// Read timeout, if configured.
    pub fn get_read_timeout(&self) -> Option<&Duration> {
        self.0.read_timeout.as_ref()
    }

    /// Write timeout, if configured.
    pub fn get_write_timeout(&self) -> Option<&Duration> {
        self.0.write_timeout.as_ref()
    }

    /// Socket-preference flag (`true` by default).
    pub fn get_prefer_socket(&self) -> bool {
        self.0.prefer_socket
    }

    /// Returns a clone of the configured init commands.
    pub fn get_init(&self) -> Vec<String> {
        self.0.init.clone()
    }

    /// TLS settings, if configured.
    pub fn get_ssl_opts(&self) -> Option<&SslOpts> {
        self.0.ssl_opts.as_ref()
    }

    /// Pool options.
    pub fn get_pool_opts(&self) -> &PoolOpts {
        &self.0.pool_opts
    }

    /// TCP_NODELAY flag (`true` by default).
    pub fn get_tcp_nodelay(&self) -> bool {
        self.0.tcp_nodelay
    }

    /// TCP keep-alive time in milliseconds, if configured.
    pub fn get_tcp_keepalive_time_ms(&self) -> Option<u32> {
        self.0.tcp_keepalive_time
    }

    /// TCP keep-alive probe interval in seconds, if configured (Linux/macOS only).
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    pub fn get_tcp_keepalive_probe_interval_secs(&self) -> Option<u32> {
        self.0.tcp_keepalive_probe_interval_secs
    }

    /// TCP keep-alive probe count, if configured (Linux/macOS only).
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    pub fn get_tcp_keepalive_probe_count(&self) -> Option<u32> {
        self.0.tcp_keepalive_probe_count
    }

    /// TCP user timeout in milliseconds, if configured (Linux only).
    #[cfg(target_os = "linux")]
    pub fn get_tcp_user_timeout_ms(&self) -> Option<u32> {
        self.0.tcp_user_timeout
    }

    /// The configured `LocalInfileHandler`, if any.
    pub fn get_local_infile_handler(&self) -> Option<&LocalInfileHandler> {
        self.0.local_infile_handler.as_ref()
    }

    /// TCP connect timeout, if configured.
    pub fn get_tcp_connect_timeout(&self) -> Option<Duration> {
        self.0.tcp_connect_timeout
    }

    /// Local bind address, if configured.
    pub fn bind_address(&self) -> Option<&SocketAddr> {
        self.0.bind_address.as_ref()
    }

    /// Statement cache size (`DEFAULT_STMT_CACHE_SIZE` unless changed).
    pub fn get_stmt_cache_size(&self) -> usize {
        self.0.stmt_cache_size
    }

    /// Compression level, if compression is configured.
    pub fn get_compress(&self) -> Option<crate::Compression> {
        self.0.compress
    }

    /// Extra capability flags requested on top of the ones the builder does not allow.
    pub fn get_additional_capabilities(&self) -> CapabilityFlags {
        self.0.additional_capabilities
    }

    /// Connection attributes, if any (an empty map by default).
    pub fn get_connect_attrs(&self) -> Option<&HashMap<String, String>> {
        self.0.connect_attrs.as_ref()
    }

    /// Secure-auth flag (`true` by default).
    pub fn get_secure_auth(&self) -> bool {
        self.0.secure_auth
    }

    /// Cleartext-plugin flag (`false` by default).
    pub fn get_enable_cleartext_plugin(&self) -> bool {
        self.0.enable_cleartext_plugin
    }
}
/// Builder for `Opts`. Every setter consumes the builder and returns it updated; convert the
/// result with `Opts::from(builder)` / `builder.into()`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct OptsBuilder {
/// Options accumulated so far.
opts: Opts,
}
impl OptsBuilder {
pub fn new() -> Self {
OptsBuilder::default()
}
pub fn from_opts<T: Into<Opts>>(opts: T) -> Self {
OptsBuilder { opts: opts.into() }
}
pub fn from_hash_map(mut self, client: &HashMap<String, String>) -> Result<Self, UrlError> {
let mut pool_min = PoolConstraints::DEFAULT.min();
let mut pool_max = PoolConstraints::DEFAULT.max();
for (key, value) in client.iter() {
match key.as_str() {
"pool_min" => match value.parse::<usize>() {
Ok(parsed) => pool_min = parsed,
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"pool_max" => match value.parse::<usize>() {
Ok(parsed) => pool_max = parsed,
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"user" => self.opts.0.user = Some(value.to_string()),
"password" => self.opts.0.pass = Some(value.to_string()),
"host" => {
let host = url::Host::parse(value)
.unwrap_or_else(|_| url::Host::Domain(value.to_owned()));
self.opts.0.ip_or_hostname = host;
}
"port" => match value.parse::<u16>() {
Ok(parsed) => self.opts.0.tcp_port = parsed,
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"socket" => self.opts.0.socket = Some(value.to_string()),
"db_name" => self.opts.0.db_name = Some(value.to_string()),
"prefer_socket" => {
match value.parse::<bool>() {
Ok(parsed) => self.opts.0.prefer_socket = parsed,
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
}
}
"enable_cleartext_plugin" => match value.parse::<bool>() {
Ok(parsed) => self.opts.0.enable_cleartext_plugin = parsed,
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"secure_auth" => match value.parse::<bool>() {
Ok(parsed) => self.opts.0.secure_auth = parsed,
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"tcp_keepalive_time_ms" => {
self.opts.0.tcp_keepalive_time = match value.parse::<u32>() {
Ok(val) => Some(val),
_ => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
}
}
#[cfg(any(target_os = "linux", target_os = "macos",))]
"tcp_keepalive_probe_interval_secs" => {
self.opts.0.tcp_keepalive_probe_interval_secs = match value.parse::<u32>() {
Ok(val) => Some(val),
_ => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
}
}
#[cfg(any(target_os = "linux", target_os = "macos",))]
"tcp_keepalive_probe_count" => {
self.opts.0.tcp_keepalive_probe_count = match value.parse::<u32>() {
Ok(val) => Some(val),
_ => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
}
}
#[cfg(target_os = "linux")]
"tcp_user_timeout_ms" => {
self.opts.0.tcp_user_timeout = match value.parse::<u32>() {
Ok(val) => Some(val),
_ => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
}
}
"compress" => match value.parse::<u32>() {
Ok(val) => self.opts.0.compress = Some(Compression::new(val)),
Err(_) => {
match value.as_str() {
"fast" => self.opts.0.compress = Some(Compression::fast()),
"best" => self.opts.0.compress = Some(Compression::best()),
"true" => self.opts.0.compress = Some(Compression::default()),
_ => {
return Err(UrlError::InvalidValue(
key.to_string(),
value.to_string(),
)); }
}
}
},
"tcp_connect_timeout_ms" => {
self.opts.0.tcp_connect_timeout = match value.parse::<u64>() {
Ok(val) => Some(Duration::from_millis(val)),
_ => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
}
}
"stmt_cache_size" => match value.parse::<usize>() {
Ok(parsed) => self.opts.0.stmt_cache_size = parsed,
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"reset_connection" => match value.parse::<bool>() {
Ok(parsed) => {
self.opts.0.pool_opts = self.opts.0.pool_opts.with_reset_connection(parsed)
}
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"check_health" => match value.parse::<bool>() {
Ok(parsed) => {
self.opts.0.pool_opts = self.opts.0.pool_opts.with_check_health(parsed)
}
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
"max_allowed_packet" => match value.parse::<usize>() {
Ok(parsed) => self.opts.0.max_allowed_packet = Some(parsed),
Err(_) => {
return Err(UrlError::InvalidValue(key.to_string(), value.to_string()))
}
},
_ => {
return Err(UrlError::UnknownParameter(key.to_string()));
}
}
}
if let Some(pool_constraints) = PoolConstraints::new(pool_min, pool_max) {
self.opts.0.pool_opts = self.opts.0.pool_opts.with_constraints(pool_constraints);
} else {
return Err(UrlError::InvalidPoolConstraints {
min: pool_min,
max: pool_max,
});
}
Ok(self)
}
pub fn ip_or_hostname<T: Into<String>>(mut self, ip_or_hostname: Option<T>) -> Self {
let new = ip_or_hostname
.map(Into::into)
.unwrap_or_else(|| "127.0.0.1".into());
self.opts.0.ip_or_hostname =
url::Host::parse(&new).unwrap_or_else(|_| url::Host::Domain(new.to_owned()));
self
}
pub fn tcp_port(mut self, tcp_port: u16) -> Self {
self.opts.0.tcp_port = tcp_port;
self
}
pub fn socket<T: Into<String>>(mut self, socket: Option<T>) -> Self {
self.opts.0.socket = socket.map(Into::into);
self
}
pub fn max_allowed_packet(mut self, max_allowed_packet: Option<usize>) -> Self {
self.opts.0.max_allowed_packet = max_allowed_packet.map(|x| x.clamp(1024, 1073741824));
self
}
pub fn user<T: Into<String>>(mut self, user: Option<T>) -> Self {
self.opts.0.user = user.map(Into::into);
self
}
pub fn pass<T: Into<String>>(mut self, pass: Option<T>) -> Self {
self.opts.0.pass = pass.map(Into::into);
self
}
pub fn db_name<T: Into<String>>(mut self, db_name: Option<T>) -> Self {
self.opts.0.db_name = db_name.map(Into::into);
self
}
pub fn read_timeout(mut self, read_timeout: Option<Duration>) -> Self {
self.opts.0.read_timeout = read_timeout;
self
}
pub fn write_timeout(mut self, write_timeout: Option<Duration>) -> Self {
self.opts.0.write_timeout = write_timeout;
self
}
pub fn tcp_keepalive_time_ms(mut self, tcp_keepalive_time_ms: Option<u32>) -> Self {
self.opts.0.tcp_keepalive_time = tcp_keepalive_time_ms;
self
}
#[cfg(any(target_os = "linux", target_os = "macos",))]
pub fn tcp_keepalive_probe_interval_secs(
mut self,
tcp_keepalive_probe_interval_secs: Option<u32>,
) -> Self {
self.opts.0.tcp_keepalive_probe_interval_secs = tcp_keepalive_probe_interval_secs;
self
}
#[cfg(any(target_os = "linux", target_os = "macos",))]
pub fn tcp_keepalive_probe_count(mut self, tcp_keepalive_probe_count: Option<u32>) -> Self {
self.opts.0.tcp_keepalive_probe_count = tcp_keepalive_probe_count;
self
}
#[cfg(target_os = "linux")]
pub fn tcp_user_timeout_ms(mut self, tcp_user_timeout_ms: Option<u32>) -> Self {
self.opts.0.tcp_user_timeout = tcp_user_timeout_ms;
self
}
pub fn tcp_nodelay(mut self, nodelay: bool) -> Self {
self.opts.0.tcp_nodelay = nodelay;
self
}
pub fn prefer_socket(mut self, prefer_socket: bool) -> Self {
self.opts.0.prefer_socket = prefer_socket;
self
}
pub fn init<T: Into<String>>(mut self, init: Vec<T>) -> Self {
self.opts.0.init = init.into_iter().map(Into::into).collect();
self
}
pub fn ssl_opts<T: Into<Option<SslOpts>>>(mut self, ssl_opts: T) -> Self {
self.opts.0.ssl_opts = ssl_opts.into();
self
}
pub fn pool_opts<T: Into<Option<PoolOpts>>>(mut self, pool_opts: T) -> Self {
self.opts.0.pool_opts = pool_opts.into().unwrap_or_default();
self
}
pub fn local_infile_handler(mut self, handler: Option<LocalInfileHandler>) -> Self {
self.opts.0.local_infile_handler = handler;
self
}
pub fn tcp_connect_timeout(mut self, timeout: Option<Duration>) -> Self {
self.opts.0.tcp_connect_timeout = timeout;
self
}
pub fn bind_address<T>(mut self, bind_address: Option<T>) -> Self
where
T: Into<SocketAddr>,
{
self.opts.0.bind_address = bind_address.map(Into::into);
self
}
pub fn stmt_cache_size<T>(mut self, cache_size: T) -> Self
where
T: Into<Option<usize>>,
{
self.opts.0.stmt_cache_size = cache_size.into().unwrap_or(128);
self
}
pub fn compress(mut self, compress: Option<crate::Compression>) -> Self {
self.opts.0.compress = compress;
self
}
pub fn additional_capabilities(mut self, additional_capabilities: CapabilityFlags) -> Self {
let forbidden_flags: CapabilityFlags = CapabilityFlags::CLIENT_PROTOCOL_41
| CapabilityFlags::CLIENT_SSL
| CapabilityFlags::CLIENT_COMPRESS
| CapabilityFlags::CLIENT_SECURE_CONNECTION
| CapabilityFlags::CLIENT_LONG_PASSWORD
| CapabilityFlags::CLIENT_TRANSACTIONS
| CapabilityFlags::CLIENT_LOCAL_FILES
| CapabilityFlags::CLIENT_MULTI_STATEMENTS
| CapabilityFlags::CLIENT_MULTI_RESULTS
| CapabilityFlags::CLIENT_PS_MULTI_RESULTS;
self.opts.0.additional_capabilities = additional_capabilities & !forbidden_flags;
self
}
pub fn connect_attrs<T1: Into<String> + Eq + Hash, T2: Into<String>>(
mut self,
connect_attrs: Option<HashMap<T1, T2>>,
) -> Self {
if let Some(connect_attrs) = connect_attrs {
let mut attrs = HashMap::with_capacity(connect_attrs.len());
for (name, value) in connect_attrs {
let name = name.into();
if !name.starts_with('_') {
attrs.insert(name, value.into());
}
}
self.opts.0.connect_attrs = Some(attrs);
} else {
self.opts.0.connect_attrs = None;
}
self
}
pub fn secure_auth(mut self, secure_auth: bool) -> Self {
self.opts.0.secure_auth = secure_auth;
self
}
pub fn enable_cleartext_plugin(mut self, enable_cleartext_plugin: bool) -> Self {
self.opts.0.enable_cleartext_plugin = enable_cleartext_plugin;
self
}
}
impl From<OptsBuilder> for Opts {
fn from(builder: OptsBuilder) -> Opts {
builder.opts
}
}
/// Percent-decodes the URL's user name (invalid UTF-8 replaced lossily).
/// Returns `None` when the user name is empty.
fn get_opts_user_from_url(url: &Url) -> Option<String> {
    let raw = url.username();
    if raw.is_empty() {
        return None;
    }
    let decoded = percent_decode(raw.as_bytes()).decode_utf8_lossy();
    Some(decoded.into_owned())
}
/// Percent-decodes the URL's password (invalid UTF-8 replaced lossily).
/// Returns `None` when the URL carries no password.
fn get_opts_pass_from_url(url: &Url) -> Option<String> {
    let raw = url.password()?;
    let decoded = percent_decode(raw.as_bytes()).decode_utf8_lossy();
    Some(decoded.into_owned())
}
/// Percent-decodes the first path segment of the URL as the database name (invalid UTF-8
/// replaced lossily).
/// Returns `None` when the URL has no path segments or the first one is empty.
fn get_opts_db_name_from_url(url: &Url) -> Option<String> {
    let first = url.path_segments()?.next()?;
    if first.is_empty() {
        None
    } else {
        Some(percent_decode(first.as_bytes()).decode_utf8_lossy().into_owned())
    }
}
/// Parses the parts of a `mysql://` URL that do not come from the query string.
///
/// Returns options built from `InnerOpts::default()` with user, password, host, port (3306
/// if absent) and database filled in, plus the decoded query pairs, which are left for the
/// caller to interpret.
///
/// # Errors
///
/// - The `Url::parse` error (converted into `UrlError`) for unparsable input.
/// - `UrlError::UnsupportedScheme` for any scheme other than `mysql`.
/// - `UrlError::BadUrl` for cannot-be-a-base URLs and for missing or unparsable hosts.
fn from_url_basic(url_str: &str) -> Result<(Opts, Vec<(String, String)>), UrlError> {
    let url = Url::parse(url_str)?;

    let scheme = url.scheme();
    if scheme != "mysql" {
        return Err(UrlError::UnsupportedScheme(scheme.to_string()));
    }
    if url.cannot_be_a_base() {
        return Err(UrlError::BadUrl);
    }

    // The host is rendered to text and re-parsed with `url::Host::parse`, so the stored host
    // goes through `Host::parse` validation rather than being copied as-is.
    let raw_host = url.host().ok_or(UrlError::BadUrl)?;
    let ip_or_hostname =
        url::Host::parse(&raw_host.to_string()).map_err(|_| UrlError::BadUrl)?;

    let inner = InnerOpts {
        user: get_opts_user_from_url(&url),
        pass: get_opts_pass_from_url(&url),
        ip_or_hostname,
        tcp_port: url.port().unwrap_or(3306),
        db_name: get_opts_db_name_from_url(&url),
        ..InnerOpts::default()
    };
    let query_pairs: Vec<(String, String)> = url.query_pairs().into_owned().collect();

    Ok((Opts(Box::new(inner)), query_pairs))
}
/// Parses a complete `mysql://` URL: the base parts via `from_url_basic`, then the query
/// parameters via `OptsBuilder::from_hash_map`. For a repeated query key, the last value wins.
fn from_url(url: &str) -> Result<Opts, UrlError> {
    let (base, pairs) = from_url_basic(url)?;
    let params: HashMap<String, String> = pairs.into_iter().collect();
    let builder = OptsBuilder::from_opts(base).from_hash_map(&params)?;
    Ok(builder.into())
}
/// Requested changes to the user, password and database of existing `Opts`.
///
/// Each field is `None` to leave the current value alone, `Some(Some(v))` to replace it, or
/// `Some(None)` to clear it. `Debug` is implemented by hand so the password is masked.
#[derive(Clone, Eq, PartialEq)]
pub struct ChangeUserOpts {
/// New user name, if a change is requested.
user: Option<Option<String>>,
/// New password, if a change is requested.
pass: Option<Option<String>>,
/// New database name, if a change is requested.
db_name: Option<Option<String>>,
}
impl ChangeUserOpts {
    /// Options that request no change at all.
    pub const DEFAULT: Self = Self {
        user: None,
        pass: None,
        db_name: None,
    };

    /// Applies the requested changes to `opts` in place.
    ///
    /// Unset fields are left untouched, and `Some(None)` clears the corresponding value. When
    /// nothing is requested, `opts` is not rebuilt.
    pub(crate) fn update_opts(self, opts: &mut Opts) {
        let Self {
            user,
            pass,
            db_name,
        } = self;
        let nothing_requested = user.is_none() && pass.is_none() && db_name.is_none();
        if nothing_requested {
            return;
        }

        let mut builder = OptsBuilder::from_opts(opts.clone());
        if let Some(new_user) = user {
            builder = builder.user(new_user);
        }
        if let Some(new_pass) = pass {
            builder = builder.pass(new_pass);
        }
        if let Some(new_db_name) = db_name {
            builder = builder.db_name(new_db_name);
        }
        *opts = builder.into();
    }

    /// Creates options that request no change (same as `ChangeUserOpts::DEFAULT`).
    pub fn new() -> Self {
        Self::DEFAULT
    }

    /// Requests a user change (`None` clears the user).
    pub fn with_user(self, user: Option<String>) -> Self {
        Self {
            user: Some(user),
            ..self
        }
    }

    /// Requests a password change (`None` clears the password).
    pub fn with_pass(self, pass: Option<String>) -> Self {
        Self {
            pass: Some(pass),
            ..self
        }
    }

    /// Requests a database change (`None` clears the database).
    pub fn with_db_name(self, db_name: Option<String>) -> Self {
        Self {
            db_name: Some(db_name),
            ..self
        }
    }

    /// The requested user change: `None` if no change was requested.
    pub fn user(&self) -> Option<Option<&str>> {
        self.user.as_ref().map(Option::as_deref)
    }

    /// The requested password change: `None` if no change was requested.
    pub fn pass(&self) -> Option<Option<&str>> {
        self.pass.as_ref().map(Option::as_deref)
    }

    /// The requested database change: `None` if no change was requested.
    pub fn db_name(&self) -> Option<Option<&str>> {
        self.db_name.as_ref().map(Option::as_deref)
    }
}
impl Default for ChangeUserOpts {
    /// Options that request no change; identical to `ChangeUserOpts::DEFAULT`.
    fn default() -> Self {
        Self::DEFAULT
    }
}
impl fmt::Debug for ChangeUserOpts {
    /// Formats every field, but shows any configured password as `"..."`, so the secret is
    /// never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let masked_pass: Option<Option<&str>> = match &self.pass {
            Some(Some(_)) => Some(Some("...")),
            Some(None) => Some(None),
            None => None,
        };
        let mut out = f.debug_struct("ChangeUserOpts");
        out.field("user", &self.user);
        out.field("pass", &masked_pass);
        out.field("db_name", &self.db_name);
        out.finish()
    }
}
// Unit tests for turning connection URLs and key/value maps into `Opts`.
#[cfg(test)]
mod test {
use mysql_common::proto::codec::Compression;
use std::time::Duration;
use super::{InnerOpts, Opts, OptsBuilder};
// Builds a `Conn` and a `Pool` from each of the three option sources (URL, `Opts`,
// `OptsBuilder`). No test below calls it, hence `dead_code` is allowed; running it would
// presumably need a reachable server.
#[allow(dead_code)]
fn assert_conn_from_url_opts_optsbuilder(url: &str, opts: Opts, opts_builder: OptsBuilder) {
crate::Conn::new(url).unwrap();
crate::Conn::new(opts.clone()).unwrap();
crate::Conn::new(opts_builder.clone()).unwrap();
crate::Pool::new(url).unwrap();
crate::Pool::new(opts).unwrap();
crate::Pool::new(opts_builder).unwrap();
}
// A URL whose path is just `/` must yield no database name.
#[test]
fn should_report_empty_url_database_as_none() {
let opt = Opts::from_url("mysql://localhost/").unwrap();
assert_eq!(opt.get_db_name(), None);
}
// Full URL round-trip: percent-decoded credentials and db name, port, socket, compression and
// the keep-alive options. Platform-specific query parameters are only appended (and expected)
// on the platforms where the corresponding fields exist.
#[test]
fn should_convert_url_into_opts() {
#[cfg(any(target_os = "linux", target_os = "macos",))]
let tcp_keepalive_probe_interval_secs = "&tcp_keepalive_probe_interval_secs=8";
#[cfg(not(any(target_os = "linux", target_os = "macos",)))]
let tcp_keepalive_probe_interval_secs = "";
#[cfg(any(target_os = "linux", target_os = "macos",))]
let tcp_keepalive_probe_count = "&tcp_keepalive_probe_count=5";
#[cfg(not(any(target_os = "linux", target_os = "macos",)))]
let tcp_keepalive_probe_count = "";
#[cfg(target_os = "linux")]
let tcp_user_timeout = "&tcp_user_timeout_ms=6000";
#[cfg(not(target_os = "linux"))]
let tcp_user_timeout = "";
let opts = format!(
"mysql://us%20r:p%20w@localhost:3308/db%2dname?prefer_socket=false&tcp_keepalive_time_ms=5000{}{}{}&socket=%2Ftmp%2Fmysql.sock&compress=8",
tcp_keepalive_probe_interval_secs,
tcp_keepalive_probe_count,
tcp_user_timeout,
);
assert_eq!(
Opts(Box::new(InnerOpts {
user: Some("us r".to_string()),
pass: Some("p w".to_string()),
ip_or_hostname: url::Host::Domain("localhost".to_string()),
tcp_port: 3308,
db_name: Some("db-name".to_string()),
prefer_socket: false,
tcp_keepalive_time: Some(5000),
#[cfg(any(target_os = "linux", target_os = "macos",))]
tcp_keepalive_probe_interval_secs: Some(8),
#[cfg(any(target_os = "linux", target_os = "macos",))]
tcp_keepalive_probe_count: Some(5),
#[cfg(target_os = "linux")]
tcp_user_timeout: Some(6000),
socket: Some("/tmp/mysql.sock".into()),
compress: Some(Compression::new(8)),
..InnerOpts::default()
})),
Opts::from_url(&opts).unwrap(),
);
}
// Input that is not a URL at all must be rejected.
#[test]
#[should_panic]
fn should_panic_on_invalid_url() {
let opts = "42";
Opts::from_url(opts).unwrap();
}
// Only the `mysql` scheme is accepted.
#[test]
#[should_panic]
fn should_panic_on_invalid_scheme() {
let opts = "postgres://localhost";
Opts::from_url(opts).unwrap();
}
// Unrecognised query parameters are an error, not silently ignored.
#[test]
#[should_panic]
fn should_panic_on_unknown_query_param() {
let opts = "mysql://localhost/foo?bar=baz";
Opts::from_url(opts).unwrap();
}
// `OptsBuilder::from_hash_map` parses each supported key into the matching option.
#[test]
fn should_read_hashmap_into_opts() {
use crate::OptsBuilder;
// Builds a `HashMap` from `key => value` pairs.
macro_rules! map(
{ $($key:expr => $value:expr), + }=> {
{
let mut h = std::collections::HashMap::new();
$(
h.insert($key, $value);
)+
h
}
};
);
let mut cnf_map = map! {
"user".to_string() => "test".to_string(),
"password".to_string() => "password".to_string(),
"host".to_string() => "127.0.0.1".to_string(),
"port".to_string() => "8080".to_string(),
"db_name".to_string() => "test_db".to_string(),
"prefer_socket".to_string() => "false".to_string(),
"tcp_keepalive_time_ms".to_string() => "5000".to_string(),
"compress".to_string() => "best".to_string(),
"tcp_connect_timeout_ms".to_string() => "1000".to_string(),
"stmt_cache_size".to_string() => "33".to_string(),
"max_allowed_packet".to_string() => "65536".to_string()
};
// Keys that exist only on Linux/macOS are added conditionally.
#[cfg(any(target_os = "linux", target_os = "macos",))]
cnf_map.insert(
"tcp_keepalive_probe_interval_secs".to_string(),
"8".to_string(),
);
#[cfg(any(target_os = "linux", target_os = "macos",))]
cnf_map.insert("tcp_keepalive_probe_count".to_string(), "5".to_string());
let parsed_opts = OptsBuilder::new().from_hash_map(&cnf_map).unwrap();
assert_eq!(parsed_opts.opts.get_user(), Some("test"));
assert_eq!(parsed_opts.opts.get_pass(), Some("password"));
assert_eq!(parsed_opts.opts.get_ip_or_hostname(), "127.0.0.1");
assert_eq!(parsed_opts.opts.get_tcp_port(), 8080);
assert_eq!(parsed_opts.opts.get_db_name(), Some("test_db"));
assert_eq!(parsed_opts.opts.get_max_allowed_packet(), Some(65536));
assert!(!parsed_opts.opts.get_prefer_socket());
assert_eq!(parsed_opts.opts.get_tcp_keepalive_time_ms(), Some(5000));
#[cfg(any(target_os = "linux", target_os = "macos",))]
assert_eq!(
parsed_opts.opts.get_tcp_keepalive_probe_interval_secs(),
Some(8)
);
#[cfg(any(target_os = "linux", target_os = "macos",))]
assert_eq!(parsed_opts.opts.get_tcp_keepalive_probe_count(), Some(5));
assert_eq!(
parsed_opts.opts.get_compress(),
Some(crate::Compression::best())
);
assert_eq!(
parsed_opts.opts.get_tcp_connect_timeout(),
Some(Duration::from_millis(1000))
);
assert_eq!(parsed_opts.opts.get_stmt_cache_size(), 33);
}
// An unparsable value is reported as `UrlError::InvalidValue(key, value)`.
#[test]
fn should_have_url_err() {
use crate::OptsBuilder;
use crate::UrlError;
// Builds a `HashMap` from `key => value` pairs.
macro_rules! map(
{ $($key:expr => $value:expr), + }=> {
{
let mut h = std::collections::HashMap::new();
$(
h.insert($key, $value);
)+
h
}
};
);
let cnf_map = map! {
"user".to_string() => "test".to_string(),
"password".to_string() => "password".to_string(),
"host".to_string() => "127.0.0.1".to_string(),
"port".to_string() => "NOTAPORT".to_string(),
"db_name".to_string() => "test_db".to_string(),
"prefer_socket".to_string() => "false".to_string(),
"tcp_keepalive_time_ms".to_string() => "5000".to_string(),
"compress".to_string() => "best".to_string(),
"tcp_connect_timeout_ms".to_string() => "1000".to_string(),
"stmt_cache_size".to_string() => "33".to_string()
};
let parsed = OptsBuilder::new().from_hash_map(&cnf_map);
assert_eq!(
parsed,
Err(UrlError::InvalidValue(
"port".to_string(),
"NOTAPORT".to_string()
))
);
}
}