use std::{
    borrow::Cow,
    net::{Ipv4Addr, Ipv6Addr},
    path::Path,
    str::FromStr,
    sync::Arc,
};

use percent_encoding::percent_decode;
use url::{Host, Url};

use crate::{
    consts::CapabilityFlags,
    error::*,
    local_infile_handler::{LocalInfileHandler, LocalInfileHandlerObject},
};
/// Default pool bounds (`min`, `max`).
///
/// Returned by `PoolConstraints::default()`, used by `OptsBuilder::pool_constraints(None)`,
/// and used as the starting `pool_min`/`pool_max` values when parsing a URL.
const DEFAULT_POOL_CONSTRAINTS: PoolConstraints = PoolConstraints { min: 10, max: 100 };
// Compile-time guard: the default bounds must satisfy the same `min <= max`
// invariant that `PoolConstraints::new` checks at runtime.
const_assert!(
_DEFAULT_POOL_CONSTRAINTS_ARE_CORRECT,
DEFAULT_POOL_CONSTRAINTS.min <= DEFAULT_POOL_CONSTRAINTS.max,
);
/// Default statement cache size for `InnerOpts::default()`; also used when
/// `OptsBuilder::stmt_cache_size` receives `None`.
const DEFAULT_STMT_CACHE_SIZE: usize = 10;
/// TLS (SSL) settings for a connection.
///
/// Every field starts out empty / `false` (`Default`). Configure it with the chainable
/// `set_*` methods, each of which returns `&mut Self`. Read it back with the
/// corresponding getters.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
pub struct SslOpts {
    /// Path to a PKCS #12 archive, if any.
    pkcs12_path: Option<Cow<'static, Path>>,
    /// Password string stored alongside the archive path, if any.
    password: Option<Cow<'static, str>>,
    /// Path to a root certificate, if any.
    root_cert_path: Option<Cow<'static, Path>>,
    /// Set through `set_danger_skip_domain_validation`.
    skip_domain_validation: bool,
    /// Set through `set_danger_accept_invalid_certs`.
    accept_invalid_certs: bool,
}

impl SslOpts {
    /// Stores the PKCS #12 archive path; `None` clears it.
    pub fn set_pkcs12_path<T: Into<Cow<'static, Path>>>(
        &mut self,
        pkcs12_path: Option<T>,
    ) -> &mut Self {
        self.pkcs12_path = pkcs12_path.map(|path| path.into());
        self
    }

    /// Stores the password; `None` clears it.
    pub fn set_password<T: Into<Cow<'static, str>>>(&mut self, password: Option<T>) -> &mut Self {
        self.password = password.map(|secret| secret.into());
        self
    }

    /// Stores the root certificate path; `None` clears it.
    pub fn set_root_cert_path<T: Into<Cow<'static, Path>>>(
        &mut self,
        root_cert_path: Option<T>,
    ) -> &mut Self {
        self.root_cert_path = root_cert_path.map(|path| path.into());
        self
    }

    /// Enables or disables domain validation skipping (the `danger_` prefix is deliberate).
    pub fn set_danger_skip_domain_validation(&mut self, value: bool) -> &mut Self {
        self.skip_domain_validation = value;
        self
    }

    /// Enables or disables acceptance of invalid certificates (the `danger_` prefix is deliberate).
    pub fn set_danger_accept_invalid_certs(&mut self, value: bool) -> &mut Self {
        self.accept_invalid_certs = value;
        self
    }

    /// Borrowed PKCS #12 archive path, if set.
    pub fn pkcs12_path(&self) -> Option<&Path> {
        self.pkcs12_path.as_ref().map(|path| &**path)
    }

    /// Borrowed password, if set.
    pub fn password(&self) -> Option<&str> {
        self.password.as_ref().map(|secret| &**secret)
    }

    /// Borrowed root certificate path, if set.
    pub fn root_cert_path(&self) -> Option<&Path> {
        self.root_cert_path.as_ref().map(|path| &**path)
    }

    /// Current value of the skip-domain-validation flag.
    pub fn skip_domain_validation(&self) -> bool {
        self.skip_domain_validation
    }

    /// Current value of the accept-invalid-certs flag.
    pub fn accept_invalid_certs(&self) -> bool {
        self.accept_invalid_certs
    }
}
/// The concrete option set behind `Opts`.
///
/// Produced by `InnerOpts::default()`, `OptsBuilder`, or URL parsing (`from_url`), and
/// read through the `Opts::get_*` accessors.
#[derive(Clone, Eq, PartialEq, Debug)]
pub struct InnerOpts {
/// Host to connect to: an IP literal or a hostname (default `"127.0.0.1"`).
ip_or_hostname: String,
/// TCP port (default 3306).
tcp_port: u16,
/// User name; `None` when not specified.
user: Option<String>,
/// Password; `None` when not specified.
pass: Option<String>,
/// Database name; when `Some`, `Opts::get_capabilities` adds `CLIENT_CONNECT_WITH_DB`.
db_name: Option<String>,
/// TCP keepalive value; NOTE(review): the unit is not established in this file, confirm at the use site.
tcp_keepalive: Option<u32>,
/// TCP no-delay flag (default `true`).
tcp_nodelay: bool,
/// Local-infile handler, if one was configured.
local_infile_handler: Option<LocalInfileHandlerObject>,
/// Pool size bounds (default `DEFAULT_POOL_CONSTRAINTS`).
pool_constraints: PoolConstraints,
/// Presumably a connection time-to-live; the unit is not established in this file, so confirm it.
conn_ttl: Option<u32>,
/// Statement strings set through `OptsBuilder::init` (default empty).
init: Vec<String>,
/// Statement cache size (default `DEFAULT_STMT_CACHE_SIZE`).
stmt_cache_size: usize,
/// TLS settings; `Some` makes `Opts::get_capabilities` request `CLIENT_SSL`.
ssl_opts: Option<SslOpts>,
/// Whether a socket connection is preferred (default `true`).
prefer_socket: bool,
/// Explicit socket; presumably a Unix socket path or named pipe, confirm at the connect site.
socket: Option<String>,
}
/// Connection options.
///
/// A handle around `Arc<InnerOpts>`. Cloning an `Opts` clones the `Arc`, so copies share
/// one underlying option set. Construct it with `OptsBuilder` (via `From`),
/// `Opts::from_url`, `str::parse`, or `Opts::default()`.
#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct Opts {
inner: Arc<InnerOpts>,
}
impl Opts {
    /// Returns `true` if the configured host refers to the local machine.
    ///
    /// The host counts as local when it is:
    /// - an IPv4 loopback literal;
    /// - an IPv6 loopback literal, either bare or wrapped in the `[` `]` brackets used in URLs;
    /// - the hostname `localhost`, compared ASCII case-insensitively because DNS names are
    ///   case-insensitive.
    #[doc(hidden)]
    pub fn addr_is_loopback(&self) -> bool {
        let host = self.inner.ip_or_hostname.as_str();
        // `Ipv6Addr::from_str` rejects the URL form `[::1]`, so strip the brackets first.
        // Both bracket characters are single-byte ASCII, so the slice stays on char boundaries.
        let v6_candidate = if host.len() >= 2 && host.starts_with('[') && host.ends_with(']') {
            &host[1..host.len() - 1]
        } else {
            host
        };
        if let Ok(addr) = Ipv4Addr::from_str(host) {
            addr.is_loopback()
        } else if let Ok(addr) = Ipv6Addr::from_str(v6_candidate) {
            addr.is_loopback()
        } else {
            host.eq_ignore_ascii_case("localhost")
        }
    }

    /// Parses a `mysql://` connection URL into options.
    ///
    /// # Errors
    /// Propagates the `UrlError` produced by the module-level URL parser.
    pub fn from_url(url: &str) -> std::result::Result<Opts, UrlError> {
        Ok(Opts {
            inner: Arc::new(from_url(url)?),
        })
    }

    /// Host to connect to (IP literal or hostname).
    pub fn get_ip_or_hostname(&self) -> &str {
        &*self.inner.ip_or_hostname
    }

    /// TCP port to connect to.
    pub fn get_tcp_port(&self) -> u16 {
        self.inner.tcp_port
    }

    /// User name, if configured.
    pub fn get_user(&self) -> Option<&str> {
        self.inner.user.as_ref().map(AsRef::as_ref)
    }

    /// Password, if configured.
    pub fn get_pass(&self) -> Option<&str> {
        self.inner.pass.as_ref().map(AsRef::as_ref)
    }

    /// Database name, if configured.
    pub fn get_db_name(&self) -> Option<&str> {
        self.inner.db_name.as_ref().map(AsRef::as_ref)
    }

    /// Statement strings configured via `OptsBuilder::init`.
    pub fn get_init(&self) -> &[String] {
        self.inner.init.as_ref()
    }

    /// TCP keepalive value, if configured (unit not established in this module).
    pub fn get_tcp_keepalive(&self) -> Option<u32> {
        self.inner.tcp_keepalive
    }

    /// Whether TCP no-delay is enabled (default `true`).
    pub fn get_tcp_nodelay(&self) -> bool {
        self.inner.tcp_nodelay
    }

    /// Shared handle to the configured local-infile handler, if any.
    pub fn get_local_infile_handler(&self) -> Option<Arc<dyn LocalInfileHandler>> {
        self.inner
            .local_infile_handler
            .as_ref()
            .map(|x| x.clone_inner())
    }

    /// Pool size bounds.
    pub fn get_pool_constraints(&self) -> &PoolConstraints {
        &self.inner.pool_constraints
    }

    /// Connection TTL, if configured (unit not established in this module).
    pub fn get_conn_ttl(&self) -> Option<u32> {
        self.inner.conn_ttl
    }

    /// Statement cache size.
    pub fn get_stmt_cache_size(&self) -> usize {
        self.inner.stmt_cache_size
    }

    /// TLS settings; `Some` means `CLIENT_SSL` is requested (see `get_capabilities`).
    pub fn get_ssl_opts(&self) -> Option<&SslOpts> {
        self.inner.ssl_opts.as_ref()
    }

    /// Misspelled alias of `get_prefer_socket`, kept so existing callers keep compiling.
    pub fn get_perfer_socket(&self) -> bool {
        self.get_prefer_socket()
    }

    /// Whether a socket connection is preferred (default `true`).
    pub fn get_prefer_socket(&self) -> bool {
        self.inner.prefer_socket
    }

    /// Explicitly configured socket, if any.
    pub fn get_socket(&self) -> Option<&str> {
        self.inner.socket.as_ref().map(|x| &**x)
    }

    /// Client capability flags to announce for these options.
    ///
    /// Always includes a fixed base set. `CLIENT_CONNECT_WITH_DB` is added when a
    /// database name is set, and `CLIENT_SSL` when TLS options are set.
    pub(crate) fn get_capabilities(&self) -> CapabilityFlags {
        let mut out = CapabilityFlags::CLIENT_PROTOCOL_41
            | CapabilityFlags::CLIENT_SECURE_CONNECTION
            | CapabilityFlags::CLIENT_LONG_PASSWORD
            | CapabilityFlags::CLIENT_TRANSACTIONS
            | CapabilityFlags::CLIENT_LOCAL_FILES
            | CapabilityFlags::CLIENT_MULTI_STATEMENTS
            | CapabilityFlags::CLIENT_MULTI_RESULTS
            | CapabilityFlags::CLIENT_PS_MULTI_RESULTS
            | CapabilityFlags::CLIENT_DEPRECATE_EOF
            | CapabilityFlags::CLIENT_PLUGIN_AUTH;
        if self.inner.db_name.is_some() {
            out |= CapabilityFlags::CLIENT_CONNECT_WITH_DB;
        }
        if self.inner.ssl_opts.is_some() {
            out |= CapabilityFlags::CLIENT_SSL;
        }
        out
    }
}
impl Default for InnerOpts {
fn default() -> InnerOpts {
InnerOpts {
ip_or_hostname: "127.0.0.1".to_string(),
tcp_port: 3306,
user: None,
pass: None,
db_name: None,
init: vec![],
tcp_keepalive: None,
tcp_nodelay: true,
local_infile_handler: None,
pool_constraints: Default::default(),
conn_ttl: None,
stmt_cache_size: DEFAULT_STMT_CACHE_SIZE,
ssl_opts: None,
prefer_socket: true,
socket: None,
}
}
}
/// Lower and upper bounds on the number of connections in a pool.
///
/// `PoolConstraints::new` only constructs values satisfying `min <= max`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct PoolConstraints {
    min: usize,
    max: usize,
}

impl PoolConstraints {
    /// Builds constraints from the given bounds.
    ///
    /// Returns `None` when `min > max`; equal bounds are accepted.
    pub fn new(min: usize, max: usize) -> Option<PoolConstraints> {
        if min > max {
            return None;
        }
        Some(PoolConstraints { min, max })
    }

    /// Lower bound.
    pub fn min(&self) -> usize {
        let PoolConstraints { min, .. } = *self;
        min
    }

    /// Upper bound.
    pub fn max(&self) -> usize {
        let PoolConstraints { max, .. } = *self;
        max
    }
}
impl Default for PoolConstraints {
    /// Returns the module's `DEFAULT_POOL_CONSTRAINTS`.
    fn default() -> PoolConstraints {
        DEFAULT_POOL_CONSTRAINTS
    }
}
impl From<PoolConstraints> for (usize, usize) {
    /// Converts the constraints into a `(min, max)` tuple.
    fn from(constraints: PoolConstraints) -> (usize, usize) {
        let PoolConstraints { min, max } = constraints;
        (min, max)
    }
}
/// Mutable builder for `Opts`.
///
/// Starts from `InnerOpts::default()`. Each setter overwrites one field and returns
/// `&mut Self` for chaining. Turn the builder into `Opts` with `Opts::from(builder)`.
#[derive(Debug, Clone, Eq, PartialEq, Default)]
pub struct OptsBuilder {
    opts: InnerOpts,
}

impl OptsBuilder {
    /// Creates a builder holding the default options.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a builder seeded with a copy of existing options.
    ///
    /// Accepts anything convertible into `Opts`. Passing a URL string goes through
    /// `From<T: AsRef<str>> for Opts`, which panics on an invalid URL.
    pub fn from_opts<T: Into<Opts>>(opts: T) -> Self {
        let source: Opts = opts.into();
        let shared: &InnerOpts = &source.inner;
        OptsBuilder {
            opts: shared.clone(),
        }
    }

    /// Sets the host (IP literal or hostname).
    pub fn ip_or_hostname<T: Into<String>>(&mut self, ip_or_hostname: T) -> &mut Self {
        let host: String = ip_or_hostname.into();
        self.opts.ip_or_hostname = host;
        self
    }

    /// Sets the TCP port.
    pub fn tcp_port(&mut self, tcp_port: u16) -> &mut Self {
        self.opts.tcp_port = tcp_port;
        self
    }

    /// Sets or clears the user name.
    pub fn user<T: Into<String>>(&mut self, user: Option<T>) -> &mut Self {
        self.opts.user = user.map(|name| name.into());
        self
    }

    /// Sets or clears the password.
    pub fn pass<T: Into<String>>(&mut self, pass: Option<T>) -> &mut Self {
        self.opts.pass = pass.map(|secret| secret.into());
        self
    }

    /// Sets or clears the database name.
    pub fn db_name<T: Into<String>>(&mut self, db_name: Option<T>) -> &mut Self {
        self.opts.db_name = db_name.map(|name| name.into());
        self
    }

    /// Replaces the list of init statements.
    pub fn init<T: Into<String>>(&mut self, init: Vec<T>) -> &mut Self {
        let statements: Vec<String> = init.into_iter().map(Into::into).collect();
        self.opts.init = statements;
        self
    }

    /// Sets or clears the TCP keepalive value.
    pub fn tcp_keepalive<T: Into<u32>>(&mut self, tcp_keepalive: Option<T>) -> &mut Self {
        self.opts.tcp_keepalive = tcp_keepalive.map(|v| v.into());
        self
    }

    /// Enables or disables TCP no-delay.
    pub fn tcp_nodelay(&mut self, nodelay: bool) -> &mut Self {
        self.opts.tcp_nodelay = nodelay;
        self
    }

    /// Installs (or removes, with `None`) a local-infile handler.
    pub fn local_infile_handler<T>(&mut self, handler: Option<T>) -> &mut Self
    where
        T: LocalInfileHandler + 'static,
    {
        let wrapped = handler.map(LocalInfileHandlerObject::new);
        self.opts.local_infile_handler = wrapped;
        self
    }

    /// Sets pool bounds; `None` restores `DEFAULT_POOL_CONSTRAINTS`.
    pub fn pool_constraints(&mut self, pool_constraints: Option<PoolConstraints>) -> &mut Self {
        let bounds = pool_constraints.unwrap_or(DEFAULT_POOL_CONSTRAINTS);
        self.opts.pool_constraints = bounds;
        self
    }

    /// Sets or clears the connection TTL.
    pub fn conn_ttl<T: Into<u32>>(&mut self, conn_ttl: Option<T>) -> &mut Self {
        self.opts.conn_ttl = conn_ttl.map(|v| v.into());
        self
    }

    /// Sets the statement cache size; `None` restores `DEFAULT_STMT_CACHE_SIZE`.
    pub fn stmt_cache_size<T>(&mut self, cache_size: T) -> &mut Self
    where
        T: Into<Option<usize>>,
    {
        let requested: Option<usize> = cache_size.into();
        self.opts.stmt_cache_size = requested.unwrap_or(DEFAULT_STMT_CACHE_SIZE);
        self
    }

    /// Sets or clears TLS options.
    pub fn ssl_opts<T: Into<Option<SslOpts>>>(&mut self, ssl_opts: T) -> &mut Self {
        let tls: Option<SslOpts> = ssl_opts.into();
        self.opts.ssl_opts = tls;
        self
    }

    /// Sets socket preference; `None` restores the default `true`.
    pub fn prefer_socket<T: Into<Option<bool>>>(&mut self, prefer_socket: T) -> &mut Self {
        let preference: Option<bool> = prefer_socket.into();
        self.opts.prefer_socket = preference.unwrap_or(true);
        self
    }

    /// Sets or clears the explicit socket.
    pub fn socket<T: Into<String>>(&mut self, socket: Option<T>) -> &mut Self {
        self.opts.socket = socket.map(|path| path.into());
        self
    }
}
impl From<OptsBuilder> for Opts {
fn from(builder: OptsBuilder) -> Opts {
Opts {
inner: Arc::new(builder.opts),
}
}
}
/// Extracts the percent-decoded user name from `url`.
///
/// Returns `None` when the URL's username component is empty. Decoding uses
/// `decode_utf8_lossy`, so invalid UTF-8 does not produce an error.
fn get_opts_user_from_url(url: &Url) -> Option<String> {
    let user = url.username();
    if user.is_empty() {
        None
    } else {
        Some(
            percent_decode(user.as_bytes())
                .decode_utf8_lossy()
                .into_owned(),
        )
    }
}
/// Extracts the percent-decoded password from `url`, or `None` when the URL has none.
///
/// Decoding uses `decode_utf8_lossy`, so invalid UTF-8 does not produce an error.
fn get_opts_pass_from_url(url: &Url) -> Option<String> {
    url.password().map(|pass| {
        percent_decode(pass.as_bytes())
            .decode_utf8_lossy()
            .into_owned()
    })
}
/// Extracts the percent-decoded database name from the first path segment of `url`.
///
/// Returns `None` when the URL has no path segments or the first segment is empty.
/// For example, `mysql://host/` yields `None` rather than `Some("")`. An empty name
/// would otherwise make `Opts::get_capabilities` request `CLIENT_CONNECT_WITH_DB`
/// with no schema to connect to.
fn get_opts_db_name_from_url(url: &Url) -> Option<String> {
    url.path_segments()
        .and_then(|mut segments| segments.next())
        .filter(|db_name| !db_name.is_empty())
        .map(|db_name| {
            percent_decode(db_name.as_bytes())
                .decode_utf8_lossy()
                .into_owned()
        })
}
fn from_url_basic(
url_str: &str,
) -> std::result::Result<(InnerOpts, Vec<(String, String)>), UrlError> {
let url = Url::parse(url_str)?;
if url.scheme() != "mysql" {
return Err(UrlError::UnsupportedScheme {
scheme: url.scheme().to_string(),
});
}
if url.cannot_be_a_base() || !url.has_host() {
return Err(UrlError::Invalid);
}
let user = get_opts_user_from_url(&url);
let pass = get_opts_pass_from_url(&url);
let ip_or_hostname = url
.host_str()
.map(String::from)
.unwrap_or_else(|| "127.0.0.1".into());
let tcp_port = url.port().unwrap_or(3306);
let db_name = get_opts_db_name_from_url(&url);
let query_pairs = url.query_pairs().into_owned().collect();
let opts = InnerOpts {
user,
pass,
ip_or_hostname,
tcp_port,
db_name,
..InnerOpts::default()
};
Ok((opts, query_pairs))
}
/// Parses a `mysql://` URL into `InnerOpts`.
///
/// Recognised query parameters:
/// - `pool_min`, `pool_max`, `stmt_cache_size` (`usize`);
/// - `conn_ttl`, `tcp_keepalive` (`u32`);
/// - `tcp_nodelay`, `prefer_socket` (`bool`);
/// - `socket` (any string).
///
/// If a parameter appears more than once, the last occurrence wins.
///
/// # Errors
/// Any error from `from_url_basic`, plus:
/// * `UrlError::InvalidParamValue` when a value fails to parse. Its `param` is the
///   query key exactly as written in the URL.
/// * `UrlError::UnknownParameter` for any other key.
/// * `UrlError::InvalidPoolConstraints` when the resulting `pool_min > pool_max`.
fn from_url(url: &str) -> std::result::Result<InnerOpts, UrlError> {
    let (mut opts, query_pairs): (InnerOpts, _) = from_url_basic(url)?;
    let mut pool_min = DEFAULT_POOL_CONSTRAINTS.min;
    let mut pool_max = DEFAULT_POOL_CONSTRAINTS.max;
    for (key, value) in query_pairs {
        match key.as_str() {
            "pool_min" => pool_min = parse_url_param("pool_min", value)?,
            "pool_max" => pool_max = parse_url_param("pool_max", value)?,
            "conn_ttl" => opts.conn_ttl = Some(parse_url_param("conn_ttl", value)?),
            // The error names the key the user actually wrote; it previously
            // reported a nonexistent `tcp_keepalive_ms` parameter.
            "tcp_keepalive" => {
                opts.tcp_keepalive = Some(parse_url_param("tcp_keepalive", value)?)
            }
            "tcp_nodelay" => opts.tcp_nodelay = parse_url_param("tcp_nodelay", value)?,
            "stmt_cache_size" => {
                opts.stmt_cache_size = parse_url_param("stmt_cache_size", value)?
            }
            "prefer_socket" => opts.prefer_socket = parse_url_param("prefer_socket", value)?,
            "socket" => opts.socket = Some(value),
            _ => return Err(UrlError::UnknownParameter { param: key }),
        }
    }
    opts.pool_constraints = PoolConstraints::new(pool_min, pool_max).ok_or(
        UrlError::InvalidPoolConstraints {
            min: pool_min,
            max: pool_max,
        },
    )?;
    Ok(opts)
}

/// Parses one query-parameter value with `FromStr`.
///
/// On failure, returns `UrlError::InvalidParamValue` carrying `param` and the
/// offending `value`.
fn parse_url_param<T: FromStr>(param: &str, value: String) -> std::result::Result<T, UrlError> {
    match value.parse::<T>() {
        Ok(parsed) => Ok(parsed),
        Err(_) => Err(UrlError::InvalidParamValue {
            param: param.into(),
            value,
        }),
    }
}
impl FromStr for Opts {
type Err = UrlError;
fn from_str(s: &str) -> std::result::Result<Self, <Self as FromStr>::Err> {
Opts::from_url(s)
}
}
impl<T: AsRef<str> + Sized> From<T> for Opts {
fn from(url: T) -> Opts {
Opts::from_url(url.as_ref()).unwrap()
}
}
#[cfg(test)]
mod test {
    use super::{from_url, InnerOpts, Opts};

    /// A full URL fills in user, password, host, port and database; everything else
    /// stays at its default.
    #[test]
    fn should_convert_url_into_opts() {
        let expected = InnerOpts {
            user: Some("usr".to_string()),
            pass: Some("pw".to_string()),
            ip_or_hostname: "192.168.1.1".to_string(),
            tcp_port: 3309,
            db_name: Some("dbname".to_string()),
            ..InnerOpts::default()
        };
        let parsed = from_url("mysql://usr:pw@192.168.1.1:3309/dbname").unwrap();
        assert_eq!(expected, parsed);
    }

    /// A string that is not a URL at all makes the `From` conversion panic.
    #[test]
    #[should_panic]
    fn should_panic_on_invalid_url() {
        let _ = Opts::from("42");
    }

    /// Schemes other than `mysql` are rejected.
    #[test]
    #[should_panic]
    fn should_panic_on_invalid_scheme() {
        let _ = Opts::from("postgres://localhost");
    }

    /// Unrecognised query parameters are rejected.
    #[test]
    #[should_panic]
    fn should_panic_on_unknown_query_param() {
        let _ = Opts::from("mysql://localhost/foo?bar=baz");
    }
}