mod native_tls_opts;
mod rustls_opts;
#[cfg(feature = "native-tls-tls")]
pub use native_tls_opts::ClientIdentity;
#[cfg(feature = "rustls-tls")]
pub use rustls_opts::ClientIdentity;
use percent_encoding::percent_decode;
use rand::Rng;
use tokio::sync::OnceCell;
use url::{Host, Url};
use std::{
borrow::Cow,
fmt, io,
net::{IpAddr, Ipv4Addr, Ipv6Addr},
path::{Path, PathBuf},
str::FromStr,
sync::Arc,
time::{Duration, Instant},
vec,
};
use crate::{
consts::CapabilityFlags,
error::*,
local_infile_handler::{GlobalHandler, GlobalHandlerObject},
};
/// Default pool size constraints: at least 10 and at most 100 connections.
pub const DEFAULT_POOL_CONSTRAINTS: PoolConstraints = PoolConstraints { min: 10, max: 100 };
// Compile-time guard: the defaults must satisfy the same invariant that
// `PoolConstraints::new` enforces at runtime (`min <= max` and `max > 0`).
const_assert!(
    _DEFAULT_POOL_CONSTRAINTS_ARE_CORRECT,
    DEFAULT_POOL_CONSTRAINTS.min <= DEFAULT_POOL_CONSTRAINTS.max
        && 0 < DEFAULT_POOL_CONSTRAINTS.max,
);
/// Default number of entries in the per-connection statement cache.
pub const DEFAULT_STMT_CACHE_SIZE: usize = 32;
/// Default MySQL TCP port.
pub const DEFAULT_PORT: u16 = 3306;
/// Default inactive-connection TTL (zero).
///
/// `PoolOpts::active_bound` only uses the `max` pool constraint when this TTL is
/// non-zero; with the zero default it uses `min`.
pub const DEFAULT_INACTIVE_CONNECTION_TTL: Duration = Duration::from_secs(0);
/// Default TTL check interval.
///
/// Also used as the fallback when `PoolOpts::with_ttl_check_interval` is given
/// an interval shorter than one second.
pub const DEFAULT_TTL_CHECK_INTERVAL: Duration = Duration::from_secs(30);
/// Server address: either an explicit host/port pair (built by `OptsBuilder`)
/// or the parsed connection URL (built by `Opts::from_url`).
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) enum HostPortOrUrl {
    HostPort {
        /// IP address or hostname, as given by the user.
        host: String,
        port: u16,
        /// Optional pre-resolved addresses for `host`.
        resolved_ips: Option<Vec<IpAddr>>,
    },
    Url(Url),
}
impl Default for HostPortOrUrl {
fn default() -> Self {
HostPortOrUrl::HostPort {
host: "127.0.0.1".to_string(),
port: DEFAULT_PORT,
resolved_ips: None,
}
}
}
impl HostPortOrUrl {
pub fn get_ip_or_hostname(&self) -> &str {
match self {
Self::HostPort { host, .. } => host,
Self::Url(url) => url.host_str().unwrap_or("127.0.0.1"),
}
}
pub fn get_tcp_port(&self) -> u16 {
match self {
Self::HostPort { port, .. } => *port,
Self::Url(url) => url.port().unwrap_or(DEFAULT_PORT),
}
}
pub fn get_resolved_ips(&self) -> &Option<Vec<IpAddr>> {
match self {
Self::HostPort { resolved_ips, .. } => resolved_ips,
Self::Url(_) => &None,
}
}
pub fn is_loopback(&self) -> bool {
match self {
Self::HostPort {
host, resolved_ips, ..
} => {
let v4addr: Option<Ipv4Addr> = FromStr::from_str(host).ok();
let v6addr: Option<Ipv6Addr> = FromStr::from_str(host).ok();
if resolved_ips
.as_ref()
.is_some_and(|s| s.iter().any(|ip| ip.is_loopback()))
{
true
} else if let Some(addr) = v4addr {
addr.is_loopback()
} else if let Some(addr) = v6addr {
addr.is_loopback()
} else {
host == "localhost"
}
}
Self::Url(url) => match url.host() {
Some(Host::Ipv4(ip)) => ip.is_loopback(),
Some(Host::Ipv6(ip)) => ip.is_loopback(),
Some(Host::Domain(s)) => s == "localhost",
_ => false,
},
}
}
}
/// Data given either as a filesystem path (read lazily) or as an in-memory buffer.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PathOrBuf<'a> {
    /// Path to a file whose contents are read on demand by `PathOrBuf::read`.
    Path(Cow<'a, Path>),
    /// Contents provided directly.
    Buf(Cow<'a, [u8]>),
}
impl<'a> PathOrBuf<'a> {
    /// Returns the bytes this value refers to.
    ///
    /// A `Buf` is borrowed without copying. A `Path` is read in full with
    /// `tokio::fs::read`, and any I/O error is returned to the caller.
    pub async fn read(&self) -> io::Result<Cow<[u8]>> {
        let path = match self {
            PathOrBuf::Buf(bytes) => return Ok(Cow::Borrowed(bytes.as_ref())),
            PathOrBuf::Path(path) => path,
        };
        let contents = tokio::fs::read(path.as_ref()).await?;
        Ok(Cow::Owned(contents))
    }

    /// Returns a view of `self` that borrows its path or buffer instead of owning it.
    pub fn borrow(&self) -> PathOrBuf<'_> {
        match self {
            PathOrBuf::Buf(bytes) => PathOrBuf::Buf(Cow::Borrowed(&**bytes)),
            PathOrBuf::Path(path) => PathOrBuf::Path(Cow::Borrowed(&**path)),
        }
    }
}
impl From<PathBuf> for PathOrBuf<'static> {
fn from(value: PathBuf) -> Self {
Self::Path(Cow::Owned(value))
}
}
impl<'a> From<&'a Path> for PathOrBuf<'a> {
fn from(value: &'a Path) -> Self {
Self::Path(Cow::Borrowed(value))
}
}
impl From<Vec<u8>> for PathOrBuf<'static> {
fn from(value: Vec<u8>) -> Self {
Self::Buf(Cow::Owned(value))
}
}
impl<'a> From<&'a [u8]> for PathOrBuf<'a> {
fn from(value: &'a [u8]) -> Self {
Self::Buf(Cow::Borrowed(value))
}
}
/// TLS settings for a connection.
///
/// All flags default to `false` and all optional values to empty/`None`.
#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
pub struct SslOpts {
    /// Optional client identity; only present when a TLS backend feature is enabled.
    #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
    client_identity: Option<ClientIdentity>,
    /// Additional root certificates, as file paths or in-memory buffers.
    root_certs: Vec<PathOrBuf<'static>>,
    /// Set to `true` by the URL parameter `built_in_roots=false`.
    disable_built_in_roots: bool,
    /// Set to `true` by the URL parameter `verify_identity=false`.
    skip_domain_validation: bool,
    /// Set to `true` by the URL parameter `verify_ca=false`.
    accept_invalid_certs: bool,
    /// Optional hostname override; how it is applied lives in the TLS backend
    /// modules (not visible here).
    tls_hostname_override: Option<Cow<'static, str>>,
}
impl SslOpts {
    /// Sets (or clears, with `None`) the client identity.
    #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
    pub fn with_client_identity(self, identity: Option<ClientIdentity>) -> Self {
        Self {
            client_identity: identity,
            ..self
        }
    }

    /// Replaces the list of additional root certificates.
    pub fn with_root_certs(self, root_certs: Vec<PathOrBuf<'static>>) -> Self {
        Self { root_certs, ..self }
    }

    /// Sets whether built-in root certificates are disabled.
    pub fn with_disable_built_in_roots(self, disable_built_in_roots: bool) -> Self {
        Self {
            disable_built_in_roots,
            ..self
        }
    }

    /// Sets the `skip_domain_validation` flag. Dangerous: weakens TLS verification.
    pub fn with_danger_skip_domain_validation(self, value: bool) -> Self {
        Self {
            skip_domain_validation: value,
            ..self
        }
    }

    /// Sets the `accept_invalid_certs` flag. Dangerous: weakens TLS verification.
    pub fn with_danger_accept_invalid_certs(self, value: bool) -> Self {
        Self {
            accept_invalid_certs: value,
            ..self
        }
    }

    /// Sets (or clears, with `None`) the TLS hostname override.
    pub fn with_danger_tls_hostname_override<T: Into<Cow<'static, str>>>(
        self,
        domain: Option<T>,
    ) -> Self {
        let tls_hostname_override = domain.map(|name| name.into());
        Self {
            tls_hostname_override,
            ..self
        }
    }

    /// Returns the configured client identity, if any.
    #[cfg(any(feature = "native-tls-tls", feature = "rustls-tls"))]
    pub fn client_identity(&self) -> Option<&ClientIdentity> {
        match &self.client_identity {
            Some(identity) => Some(identity),
            None => None,
        }
    }

    /// Returns the additional root certificates.
    pub fn root_certs(&self) -> &[PathOrBuf<'static>] {
        self.root_certs.as_slice()
    }

    /// Returns whether built-in root certificates are disabled.
    pub fn disable_built_in_roots(&self) -> bool {
        self.disable_built_in_roots
    }

    /// Returns the `skip_domain_validation` flag.
    pub fn skip_domain_validation(&self) -> bool {
        self.skip_domain_validation
    }

    /// Returns the `accept_invalid_certs` flag.
    pub fn accept_invalid_certs(&self) -> bool {
        self.accept_invalid_certs
    }

    /// Returns the TLS hostname override, if any.
    pub fn tls_hostname_override(&self) -> Option<&str> {
        self.tls_hostname_override.as_ref().map(|name| &**name)
    }
}
/// Connection-pool options. Defaults are provided by `impl Default for PoolOpts`.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct PoolOpts {
    /// Min/max connection counts.
    constraints: PoolConstraints,
    /// Inactive-connection TTL; zero by default.
    inactive_connection_ttl: Duration,
    /// How often TTLs are checked; never below one second (see the setter).
    ttl_check_interval: Duration,
    /// Optional absolute connection lifetime.
    abs_conn_ttl: Option<Duration>,
    /// Optional upper bound of random jitter added to `abs_conn_ttl`.
    abs_conn_ttl_jitter: Option<Duration>,
    /// `true` by default.
    reset_connection: bool,
}
impl PoolOpts {
    /// Creates pool options populated with the defaults (same as `PoolOpts::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the pool's min/max connection constraints.
    pub fn with_constraints(mut self, constraints: PoolConstraints) -> Self {
        self.constraints = constraints;
        self
    }

    /// Returns the pool's min/max connection constraints.
    pub fn constraints(&self) -> PoolConstraints {
        self.constraints
    }

    /// Sets the `reset_connection` flag (`true` by default).
    pub fn with_reset_connection(mut self, reset_connection: bool) -> Self {
        self.reset_connection = reset_connection;
        self
    }

    /// Returns the `reset_connection` flag.
    pub fn reset_connection(&self) -> bool {
        self.reset_connection
    }

    /// Sets the absolute connection TTL; `None` disables it.
    pub fn with_abs_conn_ttl(mut self, ttl: Option<Duration>) -> Self {
        self.abs_conn_ttl = ttl;
        self
    }

    /// Sets the upper bound of the random jitter added to the absolute TTL.
    pub fn with_abs_conn_ttl_jitter(mut self, jitter: Option<Duration>) -> Self {
        self.abs_conn_ttl_jitter = jitter;
        self
    }

    /// Returns the absolute connection TTL, if any.
    pub fn abs_conn_ttl(&self) -> Option<Duration> {
        self.abs_conn_ttl
    }

    /// Returns the absolute-TTL jitter bound, if any.
    pub fn abs_conn_ttl_jitter(&self) -> Option<Duration> {
        self.abs_conn_ttl_jitter
    }

    /// Computes the absolute expiry instant for a connection created now.
    ///
    /// Returns `None` when no absolute TTL is configured. Otherwise the result is
    /// `now + abs_conn_ttl + j`, where `j` is drawn uniformly from
    /// `0..=abs_conn_ttl_jitter` (or is zero without a jitter bound).
    ///
    /// If that deadline cannot be represented as an [`Instant`] (for example
    /// `abs_conn_ttl=18446744073709551615` in a URL), `None` is returned instead of
    /// panicking. This is the same value used when no absolute TTL is configured.
    pub(crate) fn new_connection_ttl_deadline(&self) -> Option<Instant> {
        let ttl = self.abs_conn_ttl?;
        // Sample at the Duration's full resolution: sampling whole seconds
        // would silently turn any sub-second jitter bound into zero jitter.
        let jitter = match self.abs_conn_ttl_jitter {
            Some(max_jitter) => rand::rng().random_range(Duration::ZERO..=max_jitter),
            None => Duration::ZERO,
        };
        // `Instant + Duration` panics on overflow, so use checked arithmetic.
        ttl.checked_add(jitter)
            .and_then(|lifetime| Instant::now().checked_add(lifetime))
    }

    /// Sets the inactive-connection TTL.
    pub fn with_inactive_connection_ttl(mut self, ttl: Duration) -> Self {
        self.inactive_connection_ttl = ttl;
        self
    }

    /// Returns the inactive-connection TTL.
    pub fn inactive_connection_ttl(&self) -> Duration {
        self.inactive_connection_ttl
    }

    /// Sets the TTL check interval.
    ///
    /// Intervals shorter than one second are replaced by
    /// [`DEFAULT_TTL_CHECK_INTERVAL`].
    pub fn with_ttl_check_interval(mut self, interval: Duration) -> Self {
        if interval < Duration::from_secs(1) {
            self.ttl_check_interval = DEFAULT_TTL_CHECK_INTERVAL
        } else {
            self.ttl_check_interval = interval;
        }
        self
    }

    /// Returns the TTL check interval.
    pub fn ttl_check_interval(&self) -> Duration {
        self.ttl_check_interval
    }

    /// Returns `constraints.max` when the inactive-connection TTL is non-zero,
    /// otherwise `constraints.min`.
    pub(crate) fn active_bound(&self) -> usize {
        if self.inactive_connection_ttl > Duration::from_secs(0) {
            self.constraints.max
        } else {
            self.constraints.min
        }
    }
}
impl Default for PoolOpts {
fn default() -> Self {
Self {
constraints: DEFAULT_POOL_CONSTRAINTS,
inactive_connection_ttl: DEFAULT_INACTIVE_CONNECTION_TTL,
ttl_check_interval: DEFAULT_TTL_CHECK_INTERVAL,
abs_conn_ttl: None,
abs_conn_ttl_jitter: None,
reset_connection: true,
}
}
}
/// Shared payload of [`Opts`]: protocol-level options plus the server address.
#[derive(Clone, Eq, PartialEq, Default, Debug)]
pub(crate) struct InnerOpts {
    mysql_opts: MysqlOpts,
    address: HostPortOrUrl,
}
/// Connection options other than the server address.
///
/// Defaults are defined in `impl Default for MysqlOpts`.
#[derive(Clone, Eq, PartialEq, Debug)]
pub(crate) struct MysqlOpts {
    user: Option<String>,
    pass: Option<String>,
    db_name: Option<String>,
    /// TCP keepalive value. NOTE(review): units are not established in this
    /// file (a URL parse error once called it `tcp_keepalive_ms`) — confirm.
    tcp_keepalive: Option<u32>,
    tcp_nodelay: bool,
    local_infile_handler: Option<GlobalHandlerObject>,
    pool_opts: PoolOpts,
    conn_ttl: Option<Duration>,
    /// Statements stored as given; the code that runs them lives outside this file.
    init: Vec<String>,
    /// Statements stored as given; the code that runs them lives outside this file.
    setup: Vec<String>,
    stmt_cache_size: usize,
    /// `Some` enables `CLIENT_SSL` in `Opts::get_capabilities`.
    ssl_opts: Option<SslOptsAndCachedConnector>,
    prefer_socket: bool,
    socket: Option<String>,
    /// `Some` enables `CLIENT_COMPRESS` in `Opts::get_capabilities`.
    compression: Option<crate::Compression>,
    /// Clamped into `1024..=1073741824` by the builder and the URL parser.
    max_allowed_packet: Option<usize>,
    /// Capped by the builder and the URL parser (2147483 on Windows, 31536000 elsewhere).
    wait_timeout: Option<usize>,
    secure_auth: bool,
    client_found_rows: bool,
    enable_cleartext_plugin: bool,
}
/// Immutable connection options.
///
/// Cloning is cheap: the contents are shared through an `Arc`. Build it with
/// `OptsBuilder` or parse it from a URL with `Opts::from_url`.
#[derive(Clone, Eq, PartialEq, Debug, Default)]
pub struct Opts {
    inner: Arc<InnerOpts>,
}
impl Opts {
    /// Reports whether the configured server address refers to the local machine.
    #[doc(hidden)]
    pub fn addr_is_loopback(&self) -> bool {
        self.hostport_or_url().is_loopback()
    }

    /// Parses a `mysql://` connection URL.
    ///
    /// A URL without an explicit port gets [`DEFAULT_PORT`] filled in before its
    /// query parameters are interpreted.
    ///
    /// # Errors
    ///
    /// Returns a [`UrlError`] if the URL does not parse, the default port cannot be
    /// set, or any connection parameter is rejected.
    pub fn from_url(url: &str) -> std::result::Result<Opts, UrlError> {
        let mut parsed = Url::parse(url)?;
        if parsed.port().is_none() && parsed.set_port(Some(DEFAULT_PORT)).is_err() {
            return Err(UrlError::Invalid);
        }
        let mysql_opts = mysqlopts_from_url(&parsed)?;
        let inner = InnerOpts {
            address: HostPortOrUrl::Url(parsed),
            mysql_opts,
        };
        Ok(Opts {
            inner: Arc::new(inner),
        })
    }

    /// Returns the server host (IP literal or hostname).
    pub fn ip_or_hostname(&self) -> &str {
        self.hostport_or_url().get_ip_or_hostname()
    }

    /// Returns the underlying address representation.
    pub(crate) fn hostport_or_url(&self) -> &HostPortOrUrl {
        &self.inner.address
    }

    /// Returns the server TCP port.
    pub fn tcp_port(&self) -> u16 {
        self.hostport_or_url().get_tcp_port()
    }

    /// Returns the pre-resolved server addresses, if any.
    pub fn resolved_ips(&self) -> &Option<Vec<IpAddr>> {
        self.hostport_or_url().get_resolved_ips()
    }

    /// Returns the user name, if any.
    pub fn user(&self) -> Option<&str> {
        self.inner.mysql_opts.user.as_deref()
    }

    /// Returns the password, if any.
    pub fn pass(&self) -> Option<&str> {
        self.inner.mysql_opts.pass.as_deref()
    }

    /// Returns the database name, if any.
    pub fn db_name(&self) -> Option<&str> {
        self.inner.mysql_opts.db_name.as_deref()
    }

    /// Returns the `init` statements.
    pub fn init(&self) -> &[String] {
        &self.inner.mysql_opts.init
    }

    /// Returns the `setup` statements.
    pub fn setup(&self) -> &[String] {
        &self.inner.mysql_opts.setup
    }

    /// Returns the TCP keepalive setting, if any.
    pub fn tcp_keepalive(&self) -> Option<u32> {
        self.inner.mysql_opts.tcp_keepalive
    }

    /// Returns the TCP_NODELAY setting.
    pub fn tcp_nodelay(&self) -> bool {
        self.inner.mysql_opts.tcp_nodelay
    }

    /// Returns a shared handle to the global `LOCAL INFILE` handler, if one is set.
    pub fn local_infile_handler(&self) -> Option<Arc<dyn GlobalHandler>> {
        let handler = self.inner.mysql_opts.local_infile_handler.as_ref()?;
        Some(handler.clone_inner())
    }

    /// Returns the pool options.
    pub fn pool_opts(&self) -> &PoolOpts {
        &self.inner.mysql_opts.pool_opts
    }

    /// Returns the connection TTL, if any.
    pub fn conn_ttl(&self) -> Option<Duration> {
        self.inner.mysql_opts.conn_ttl
    }

    /// Returns the pool's absolute connection TTL, if any.
    pub fn abs_conn_ttl(&self) -> Option<Duration> {
        self.pool_opts().abs_conn_ttl
    }

    /// Returns the pool's absolute-TTL jitter bound, if any.
    pub fn abs_conn_ttl_jitter(&self) -> Option<Duration> {
        self.pool_opts().abs_conn_ttl_jitter
    }

    /// Returns the statement cache size.
    pub fn stmt_cache_size(&self) -> usize {
        self.inner.mysql_opts.stmt_cache_size
    }

    /// Returns the TLS options; `None` means TLS is not requested.
    pub fn ssl_opts(&self) -> Option<&SslOpts> {
        self.ssl_opts_and_connector()
            .map(|cached| &cached.ssl_opts)
    }

    /// Returns the `prefer_socket` flag.
    pub fn prefer_socket(&self) -> bool {
        self.inner.mysql_opts.prefer_socket
    }

    /// Returns the explicit socket path, if any.
    pub fn socket(&self) -> Option<&str> {
        self.inner.mysql_opts.socket.as_deref()
    }

    /// Returns the compression setting, if any.
    pub fn compression(&self) -> Option<crate::Compression> {
        self.inner.mysql_opts.compression
    }

    /// Returns the `max_allowed_packet` override, if any.
    pub fn max_allowed_packet(&self) -> Option<usize> {
        self.inner.mysql_opts.max_allowed_packet
    }

    /// Returns the `wait_timeout` override, if any.
    pub fn wait_timeout(&self) -> Option<usize> {
        self.inner.mysql_opts.wait_timeout
    }

    /// Returns the `secure_auth` flag.
    pub fn secure_auth(&self) -> bool {
        self.inner.mysql_opts.secure_auth
    }

    /// Returns the `client_found_rows` flag.
    pub fn client_found_rows(&self) -> bool {
        self.inner.mysql_opts.client_found_rows
    }

    /// Returns the `enable_cleartext_plugin` flag.
    pub fn enable_cleartext_plugin(&self) -> bool {
        self.inner.mysql_opts.enable_cleartext_plugin
    }

    /// Computes the client capability flags implied by these options.
    ///
    /// A fixed base set is always present. `CONNECT_WITH_DB`, `SSL`, `COMPRESS` and
    /// `FOUND_ROWS` are added when a database name, TLS options or compression are
    /// configured, or when `client_found_rows` is set, respectively.
    pub(crate) fn get_capabilities(&self) -> CapabilityFlags {
        let always_on = CapabilityFlags::CLIENT_PROTOCOL_41
            | CapabilityFlags::CLIENT_SECURE_CONNECTION
            | CapabilityFlags::CLIENT_LONG_PASSWORD
            | CapabilityFlags::CLIENT_TRANSACTIONS
            | CapabilityFlags::CLIENT_LOCAL_FILES
            | CapabilityFlags::CLIENT_MULTI_STATEMENTS
            | CapabilityFlags::CLIENT_MULTI_RESULTS
            | CapabilityFlags::CLIENT_PS_MULTI_RESULTS
            | CapabilityFlags::CLIENT_DEPRECATE_EOF
            | CapabilityFlags::CLIENT_PLUGIN_AUTH;
        let mysql_opts = &self.inner.mysql_opts;
        let conditional = [
            (
                mysql_opts.db_name.is_some(),
                CapabilityFlags::CLIENT_CONNECT_WITH_DB,
            ),
            (mysql_opts.ssl_opts.is_some(), CapabilityFlags::CLIENT_SSL),
            (
                mysql_opts.compression.is_some(),
                CapabilityFlags::CLIENT_COMPRESS,
            ),
            (self.client_found_rows(), CapabilityFlags::CLIENT_FOUND_ROWS),
        ];
        conditional
            .iter()
            .filter(|(enabled, _)| *enabled)
            .fold(always_on, |flags, (_, flag)| flags | *flag)
    }

    /// Returns the TLS options together with their cached connector, if TLS is requested.
    pub(crate) fn ssl_opts_and_connector(&self) -> Option<&SslOptsAndCachedConnector> {
        self.inner.mysql_opts.ssl_opts.as_ref()
    }
}
impl Default for MysqlOpts {
fn default() -> MysqlOpts {
MysqlOpts {
user: None,
pass: None,
db_name: None,
init: vec![],
setup: vec![],
tcp_keepalive: None,
tcp_nodelay: true,
local_infile_handler: None,
pool_opts: Default::default(),
conn_ttl: None,
stmt_cache_size: DEFAULT_STMT_CACHE_SIZE,
ssl_opts: None,
prefer_socket: cfg!(not(target_os = "windows")),
socket: None,
compression: None,
max_allowed_packet: None,
wait_timeout: None,
secure_auth: true,
client_found_rows: false,
enable_cleartext_plugin: false,
}
}
}
/// TLS options paired with a lazily built, shared TLS connector.
///
/// The connector is created at most once, on the first call to
/// `build_tls_connector`, and the `Arc` shares it between clones.
#[derive(Clone)]
pub(crate) struct SslOptsAndCachedConnector {
    ssl_opts: SslOpts,
    tls_connector: Arc<OnceCell<crate::io::TlsConnector>>,
}
impl fmt::Debug for SslOptsAndCachedConnector {
    /// Prints only `ssl_opts`; the cached connector is not included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SslOptsAndCachedConnector");
        builder.field("ssl_opts", &self.ssl_opts);
        builder.finish()
    }
}
impl SslOptsAndCachedConnector {
    /// Wraps `ssl_opts` with an empty (not yet built) connector cache.
    fn new(ssl_opts: SslOpts) -> Self {
        let tls_connector = Arc::new(OnceCell::new());
        Self {
            tls_connector,
            ssl_opts,
        }
    }

    /// Returns the wrapped TLS options.
    pub(crate) fn ssl_opts(&self) -> &SslOpts {
        &self.ssl_opts
    }

    /// Returns the TLS connector, building and caching it on first use.
    ///
    /// A failed build is returned as an error and is not cached.
    pub(crate) async fn build_tls_connector(&self) -> Result<crate::io::TlsConnector> {
        let cached = self
            .tls_connector
            .get_or_try_init(|| self.ssl_opts.build_tls_connector())
            .await?;
        Ok(cached.clone())
    }
}
impl PartialEq for SslOptsAndCachedConnector {
    /// Compares only the TLS options; the cached connector is ignored.
    fn eq(&self, other: &Self) -> bool {
        other.ssl_opts.eq(&self.ssl_opts)
    }
}
// Equality delegates to `SslOpts`, which is `Eq`.
impl Eq for SslOptsAndCachedConnector {}
/// Minimum and maximum pool size.
///
/// `PoolConstraints::new` only accepts `min <= max` with `max > 0`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
pub struct PoolConstraints {
    min: usize,
    max: usize,
}
impl PoolConstraints {
    /// Creates constraints, or returns `None` if `max` is zero or `min > max`.
    pub const fn new(min: usize, max: usize) -> Option<PoolConstraints> {
        if max == 0 || min > max {
            return None;
        }
        Some(PoolConstraints { min, max })
    }

    /// Lower bound of the pool size.
    pub fn min(&self) -> usize {
        let PoolConstraints { min, .. } = *self;
        min
    }

    /// Upper bound of the pool size.
    pub fn max(&self) -> usize {
        let PoolConstraints { max, .. } = *self;
        max
    }
}
impl Default for PoolConstraints {
    /// Returns [`DEFAULT_POOL_CONSTRAINTS`].
    fn default() -> PoolConstraints {
        DEFAULT_POOL_CONSTRAINTS
    }
}
/// Converts constraints into a `(min, max)` tuple.
impl From<PoolConstraints> for (usize, usize) {
    fn from(constraints: PoolConstraints) -> (usize, usize) {
        (constraints.min, constraints.max)
    }
}
/// Builder for [`Opts`] that uses an explicit host and port rather than a URL.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct OptsBuilder {
    opts: MysqlOpts,
    ip_or_hostname: String,
    tcp_port: u16,
    resolved_ips: Option<Vec<IpAddr>>,
}
impl Default for OptsBuilder {
fn default() -> Self {
let address = HostPortOrUrl::default();
Self {
opts: MysqlOpts::default(),
ip_or_hostname: address.get_ip_or_hostname().into(),
tcp_port: address.get_tcp_port(),
resolved_ips: None,
}
}
}
impl OptsBuilder {
    /// Creates a builder pre-populated from anything convertible into [`Opts`],
    /// such as a URL string or an existing `Opts`.
    ///
    /// # Panics
    ///
    /// Panics if the conversion into [`Opts`] fails.
    pub fn from_opts<T>(opts: T) -> Self
    where
        Opts: TryFrom<T>,
        <Opts as TryFrom<T>>::Error: std::error::Error,
    {
        let opts = Opts::try_from(opts).unwrap();
        OptsBuilder {
            tcp_port: opts.inner.address.get_tcp_port(),
            ip_or_hostname: opts.inner.address.get_ip_or_hostname().to_string(),
            resolved_ips: opts.inner.address.get_resolved_ips().clone(),
            opts: opts.inner.mysql_opts.clone(),
        }
    }

    /// Sets the server IP address or hostname.
    pub fn ip_or_hostname<T: Into<String>>(mut self, ip_or_hostname: T) -> Self {
        self.ip_or_hostname = ip_or_hostname.into();
        self
    }

    /// Sets the server TCP port.
    pub fn tcp_port(mut self, tcp_port: u16) -> Self {
        self.tcp_port = tcp_port;
        self
    }

    /// Sets (or clears, with `None`) pre-resolved addresses for the host.
    pub fn resolved_ips<T: Into<Vec<IpAddr>>>(mut self, ips: Option<T>) -> Self {
        self.resolved_ips = ips.map(Into::into);
        self
    }

    /// Sets (or clears) the user name.
    pub fn user<T: Into<String>>(mut self, user: Option<T>) -> Self {
        self.opts.user = user.map(Into::into);
        self
    }

    /// Sets (or clears) the password.
    pub fn pass<T: Into<String>>(mut self, pass: Option<T>) -> Self {
        self.opts.pass = pass.map(Into::into);
        self
    }

    /// Sets (or clears) the database name.
    pub fn db_name<T: Into<String>>(mut self, db_name: Option<T>) -> Self {
        self.opts.db_name = db_name.map(Into::into);
        self
    }

    /// Replaces the `init` statements.
    pub fn init<T: Into<String>>(mut self, init: Vec<T>) -> Self {
        self.opts.init = init.into_iter().map(Into::into).collect();
        self
    }

    /// Replaces the `setup` statements.
    pub fn setup<T: Into<String>>(mut self, setup: Vec<T>) -> Self {
        self.opts.setup = setup.into_iter().map(Into::into).collect();
        self
    }

    /// Sets (or clears) the TCP keepalive value.
    pub fn tcp_keepalive<T: Into<u32>>(mut self, tcp_keepalive: Option<T>) -> Self {
        self.opts.tcp_keepalive = tcp_keepalive.map(Into::into);
        self
    }

    /// Sets TCP_NODELAY (`true` by default).
    pub fn tcp_nodelay(mut self, nodelay: bool) -> Self {
        self.opts.tcp_nodelay = nodelay;
        self
    }

    /// Sets (or clears) the global `LOCAL INFILE` handler.
    pub fn local_infile_handler<T>(mut self, handler: Option<T>) -> Self
    where
        T: GlobalHandler,
    {
        self.opts.local_infile_handler = handler.map(GlobalHandlerObject::new);
        self
    }

    /// Sets the pool options; `None` restores `PoolOpts::default()`.
    pub fn pool_opts<T: Into<Option<PoolOpts>>>(mut self, pool_opts: T) -> Self {
        self.opts.pool_opts = pool_opts.into().unwrap_or_default();
        self
    }

    /// Sets (or clears) the connection TTL.
    pub fn conn_ttl<T: Into<Option<Duration>>>(mut self, conn_ttl: T) -> Self {
        self.opts.conn_ttl = conn_ttl.into();
        self
    }

    /// Sets the statement cache size; `None` restores [`DEFAULT_STMT_CACHE_SIZE`].
    pub fn stmt_cache_size<T>(mut self, cache_size: T) -> Self
    where
        T: Into<Option<usize>>,
    {
        self.opts.stmt_cache_size = cache_size.into().unwrap_or(DEFAULT_STMT_CACHE_SIZE);
        self
    }

    /// Sets the TLS options; `None` disables TLS.
    ///
    /// Each call starts a fresh connector cache for the new options.
    pub fn ssl_opts<T: Into<Option<SslOpts>>>(mut self, ssl_opts: T) -> Self {
        self.opts.ssl_opts = ssl_opts.into().map(SslOptsAndCachedConnector::new);
        self
    }

    /// Sets whether a socket connection is preferred.
    ///
    /// `None` restores the default from `MysqlOpts::default()`: `true` everywhere
    /// except Windows, where it is `false`.
    pub fn prefer_socket<T: Into<Option<bool>>>(mut self, prefer_socket: T) -> Self {
        // Must mirror `MysqlOpts::default().prefer_socket`. Using a hard-coded
        // `true` here made `None` disagree with the default on Windows.
        self.opts.prefer_socket = prefer_socket
            .into()
            .unwrap_or(cfg!(not(target_os = "windows")));
        self
    }

    /// Sets (or clears) an explicit socket path.
    pub fn socket<T: Into<String>>(mut self, socket: Option<T>) -> Self {
        self.opts.socket = socket.map(Into::into);
        self
    }

    /// Sets (or clears) the compression setting.
    pub fn compression<T: Into<Option<crate::Compression>>>(mut self, compression: T) -> Self {
        self.opts.compression = compression.into();
        self
    }

    /// Sets (or clears) `max_allowed_packet`, clamped into `1024..=1073741824`.
    pub fn max_allowed_packet(mut self, max_allowed_packet: Option<usize>) -> Self {
        self.opts.max_allowed_packet = max_allowed_packet.map(|x| x.clamp(1024, 1073741824));
        self
    }

    /// Sets (or clears) `wait_timeout`.
    ///
    /// The value is capped at 2147483 on Windows and 31536000 elsewhere.
    pub fn wait_timeout(mut self, wait_timeout: Option<usize>) -> Self {
        self.opts.wait_timeout = wait_timeout.map(|x| {
            #[cfg(windows)]
            let val = std::cmp::min(2147483, x);
            #[cfg(not(windows))]
            let val = std::cmp::min(31536000, x);
            val
        });
        self
    }

    /// Sets the `secure_auth` flag (`true` by default).
    pub fn secure_auth(mut self, secure_auth: bool) -> Self {
        self.opts.secure_auth = secure_auth;
        self
    }

    /// Sets the `client_found_rows` flag, which adds `CLIENT_FOUND_ROWS` to the
    /// capabilities.
    pub fn client_found_rows(mut self, client_found_rows: bool) -> Self {
        self.opts.client_found_rows = client_found_rows;
        self
    }

    /// Sets the `enable_cleartext_plugin` flag (`false` by default).
    pub fn enable_cleartext_plugin(mut self, enable_cleartext_plugin: bool) -> Self {
        self.opts.enable_cleartext_plugin = enable_cleartext_plugin;
        self
    }
}
impl From<OptsBuilder> for Opts {
fn from(builder: OptsBuilder) -> Opts {
let address = HostPortOrUrl::HostPort {
host: builder.ip_or_hostname,
port: builder.tcp_port,
resolved_ips: builder.resolved_ips,
};
let inner_opts = InnerOpts {
mysql_opts: builder.opts,
address,
};
Opts {
inner: Arc::new(inner_opts),
}
}
}
/// Credentials and database to apply on a change-user operation.
///
/// For each field, the outer `None` means "keep the current value". `Some(v)`
/// replaces it with `v`, which may itself be `None` to clear the value.
#[derive(Clone, Eq, PartialEq)]
pub struct ChangeUserOpts {
    user: Option<Option<String>>,
    pass: Option<Option<String>>,
    db_name: Option<Option<String>>,
}
impl ChangeUserOpts {
pub(crate) fn update_opts(self, opts: &mut Opts) {
if self.user.is_none() && self.pass.is_none() && self.db_name.is_none() {
return;
}
let mut builder = OptsBuilder::from_opts(opts.clone());
if let Some(user) = self.user {
builder = builder.user(user);
}
if let Some(pass) = self.pass {
builder = builder.pass(pass);
}
if let Some(db_name) = self.db_name {
builder = builder.db_name(db_name);
}
*opts = Opts::from(builder);
}
pub fn new() -> Self {
Self {
user: None,
pass: None,
db_name: None,
}
}
pub fn with_user(mut self, user: Option<String>) -> Self {
self.user = Some(user);
self
}
pub fn with_pass(mut self, pass: Option<String>) -> Self {
self.pass = Some(pass);
self
}
pub fn with_db_name(mut self, db_name: Option<String>) -> Self {
self.db_name = Some(db_name);
self
}
pub fn user(&self) -> Option<Option<&str>> {
self.user.as_ref().map(|x| x.as_deref())
}
pub fn pass(&self) -> Option<Option<&str>> {
self.pass.as_ref().map(|x| x.as_deref())
}
pub fn db_name(&self) -> Option<Option<&str>> {
self.db_name.as_ref().map(|x| x.as_deref())
}
}
impl Default for ChangeUserOpts {
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for ChangeUserOpts {
    /// Formats all fields, but masks any password value as `"..."`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let masked_pass = self
            .pass
            .as_ref()
            .map(|requested| requested.as_ref().map(|_| "..."));
        let mut out = f.debug_struct("ChangeUserOpts");
        out.field("user", &self.user);
        out.field("pass", &masked_pass);
        out.field("db_name", &self.db_name);
        out.finish()
    }
}
/// Returns the percent-decoded user name from `url`, or `None` if it is empty.
///
/// Invalid UTF-8 after decoding is replaced lossily.
fn get_opts_user_from_url(url: &Url) -> Option<String> {
    let raw_user = url.username();
    if raw_user.is_empty() {
        return None;
    }
    let decoded = percent_decode(raw_user.as_bytes()).decode_utf8_lossy();
    Some(decoded.into_owned())
}
/// Returns the percent-decoded password from `url`, if one is present.
///
/// Invalid UTF-8 after decoding is replaced lossily.
fn get_opts_pass_from_url(url: &Url) -> Option<String> {
    let raw_pass = url.password()?;
    let decoded = percent_decode(raw_pass.as_bytes()).decode_utf8_lossy();
    Some(decoded.into_owned())
}
/// Returns the percent-decoded first path segment of `url` as the database name.
///
/// Returns `None` if the URL has no path segments or the decoded segment is empty.
fn get_opts_db_name_from_url(url: &Url) -> Option<String> {
    let first_segment = url.path_segments()?.next()?;
    let db_name = percent_decode(first_segment.as_bytes())
        .decode_utf8_lossy()
        .into_owned();
    if db_name.is_empty() {
        None
    } else {
        Some(db_name)
    }
}
/// Validates the scheme and host of `url` and extracts credentials and database.
///
/// Returns default options carrying the user, password and database name, plus
/// the raw (decoded) query pairs for further interpretation.
///
/// # Errors
///
/// [`UrlError::UnsupportedScheme`] for any scheme other than `mysql`, and
/// [`UrlError::Invalid`] for URLs without a host or that cannot be a base.
fn from_url_basic(url: &Url) -> std::result::Result<(MysqlOpts, Vec<(String, String)>), UrlError> {
    match url.scheme() {
        "mysql" => {}
        other => {
            return Err(UrlError::UnsupportedScheme {
                scheme: other.to_string(),
            })
        }
    }
    if !url.has_host() || url.cannot_be_a_base() {
        return Err(UrlError::Invalid);
    }
    let opts = MysqlOpts {
        user: get_opts_user_from_url(url),
        pass: get_opts_pass_from_url(url),
        db_name: get_opts_db_name_from_url(url),
        ..MysqlOpts::default()
    };
    let query_pairs: Vec<(String, String)> = url.query_pairs().into_owned().collect();
    Ok((opts, query_pairs))
}
fn mysqlopts_from_url(url: &Url) -> std::result::Result<MysqlOpts, UrlError> {
let (mut opts, query_pairs): (MysqlOpts, _) = from_url_basic(url)?;
let mut pool_min = DEFAULT_POOL_CONSTRAINTS.min;
let mut pool_max = DEFAULT_POOL_CONSTRAINTS.max;
let mut ssl_opts = None;
let mut skip_domain_validation = false;
let mut accept_invalid_certs = false;
let mut disable_built_in_roots = false;
for (key, value) in query_pairs {
if key == "pool_min" {
match usize::from_str(&value) {
Ok(value) => pool_min = value,
_ => {
return Err(UrlError::InvalidParamValue {
param: "pool_min".into(),
value,
});
}
}
} else if key == "pool_max" {
match usize::from_str(&value) {
Ok(value) => pool_max = value,
_ => {
return Err(UrlError::InvalidParamValue {
param: "pool_max".into(),
value,
});
}
}
} else if key == "inactive_connection_ttl" {
match u64::from_str(&value) {
Ok(value) => {
opts.pool_opts = opts
.pool_opts
.with_inactive_connection_ttl(Duration::from_secs(value))
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "inactive_connection_ttl".into(),
value,
});
}
}
} else if key == "ttl_check_interval" {
match u64::from_str(&value) {
Ok(value) => {
opts.pool_opts = opts
.pool_opts
.with_ttl_check_interval(Duration::from_secs(value))
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "ttl_check_interval".into(),
value,
});
}
}
} else if key == "conn_ttl" {
match u64::from_str(&value) {
Ok(value) => opts.conn_ttl = Some(Duration::from_secs(value)),
_ => {
return Err(UrlError::InvalidParamValue {
param: "conn_ttl".into(),
value,
});
}
}
} else if key == "abs_conn_ttl" {
match u64::from_str(&value) {
Ok(value) => {
opts.pool_opts = opts
.pool_opts
.with_abs_conn_ttl(Some(Duration::from_secs(value)))
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "abs_conn_ttl".into(),
value,
});
}
}
} else if key == "abs_conn_ttl_jitter" {
match u64::from_str(&value) {
Ok(value) => {
opts.pool_opts = opts
.pool_opts
.with_abs_conn_ttl_jitter(Some(Duration::from_secs(value)))
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "abs_conn_ttl_jitter".into(),
value,
});
}
}
} else if key == "tcp_keepalive" {
match u32::from_str(&value) {
Ok(value) => opts.tcp_keepalive = Some(value),
_ => {
return Err(UrlError::InvalidParamValue {
param: "tcp_keepalive_ms".into(),
value,
});
}
}
} else if key == "max_allowed_packet" {
match usize::from_str(&value) {
Ok(value) => opts.max_allowed_packet = Some(value.clamp(1024, 1073741824)),
_ => {
return Err(UrlError::InvalidParamValue {
param: "max_allowed_packet".into(),
value,
});
}
}
} else if key == "wait_timeout" {
match usize::from_str(&value) {
#[cfg(windows)]
Ok(value) => opts.wait_timeout = Some(std::cmp::min(2147483, value)),
#[cfg(not(windows))]
Ok(value) => opts.wait_timeout = Some(std::cmp::min(31536000, value)),
_ => {
return Err(UrlError::InvalidParamValue {
param: "wait_timeout".into(),
value,
});
}
}
} else if key == "enable_cleartext_plugin" {
match bool::from_str(&value) {
Ok(parsed) => opts.enable_cleartext_plugin = parsed,
Err(_) => {
return Err(UrlError::InvalidParamValue {
param: key.to_string(),
value,
});
}
}
} else if key == "reset_connection" {
match bool::from_str(&value) {
Ok(parsed) => opts.pool_opts = opts.pool_opts.with_reset_connection(parsed),
Err(_) => {
return Err(UrlError::InvalidParamValue {
param: key.to_string(),
value,
});
}
}
} else if key == "tcp_nodelay" {
match bool::from_str(&value) {
Ok(value) => opts.tcp_nodelay = value,
_ => {
return Err(UrlError::InvalidParamValue {
param: "tcp_nodelay".into(),
value,
});
}
}
} else if key == "stmt_cache_size" {
match usize::from_str(&value) {
Ok(stmt_cache_size) => {
opts.stmt_cache_size = stmt_cache_size;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "stmt_cache_size".into(),
value,
});
}
}
} else if key == "prefer_socket" {
match bool::from_str(&value) {
Ok(prefer_socket) => {
opts.prefer_socket = prefer_socket;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "prefer_socket".into(),
value,
});
}
}
} else if key == "secure_auth" {
match bool::from_str(&value) {
Ok(secure_auth) => {
opts.secure_auth = secure_auth;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "secure_auth".into(),
value,
});
}
}
} else if key == "client_found_rows" {
match bool::from_str(&value) {
Ok(client_found_rows) => {
opts.client_found_rows = client_found_rows;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "client_found_rows".into(),
value,
});
}
}
} else if key == "socket" {
opts.socket = Some(value)
} else if key == "compression" {
if value == "fast" {
opts.compression = Some(crate::Compression::fast());
} else if value == "on" || value == "true" {
opts.compression = Some(crate::Compression::default());
} else if value == "best" {
opts.compression = Some(crate::Compression::best());
} else if value.len() == 1 && 0x30 <= value.as_bytes()[0] && value.as_bytes()[0] <= 0x39
{
opts.compression =
Some(crate::Compression::new((value.as_bytes()[0] - 0x30) as u32));
} else {
return Err(UrlError::InvalidParamValue {
param: "compression".into(),
value,
});
}
} else if key == "require_ssl" {
match bool::from_str(&value) {
Ok(x) => {
ssl_opts = x.then(SslOpts::default);
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "require_ssl".into(),
value,
});
}
}
} else if key == "verify_ca" {
match bool::from_str(&value) {
Ok(x) => {
accept_invalid_certs = !x;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "verify_ca".into(),
value,
});
}
}
} else if key == "verify_identity" {
match bool::from_str(&value) {
Ok(x) => {
skip_domain_validation = !x;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "verify_identity".into(),
value,
});
}
}
} else if key == "built_in_roots" {
match bool::from_str(&value) {
Ok(x) => {
disable_built_in_roots = !x;
}
_ => {
return Err(UrlError::InvalidParamValue {
param: "built_in_roots".into(),
value,
});
}
}
} else {
return Err(UrlError::UnknownParameter { param: key });
}
}
if let Some(pool_constraints) = PoolConstraints::new(pool_min, pool_max) {
opts.pool_opts = opts.pool_opts.with_constraints(pool_constraints);
} else {
return Err(UrlError::InvalidPoolConstraints {
min: pool_min,
max: pool_max,
});
}
if let Some(ref mut ssl_opts) = ssl_opts {
ssl_opts.accept_invalid_certs = accept_invalid_certs;
ssl_opts.skip_domain_validation = skip_domain_validation;
ssl_opts.disable_built_in_roots = disable_built_in_roots;
}
opts.ssl_opts = ssl_opts.map(SslOptsAndCachedConnector::new);
Ok(opts)
}
impl FromStr for Opts {
type Err = UrlError;
fn from_str(s: &str) -> std::result::Result<Self, <Self as FromStr>::Err> {
Opts::from_url(s)
}
}
impl TryFrom<&str> for Opts {
type Error = UrlError;
fn try_from(s: &str) -> std::result::Result<Self, UrlError> {
Opts::from_url(s)
}
}
// Unit tests for URL parsing and `OptsBuilder`.
#[cfg(test)]
mod test {
    use super::{HostPortOrUrl, MysqlOpts, Opts, Url};
    use crate::{error::UrlError::InvalidParamValue, SslOpts};
    use std::{net::IpAddr, net::Ipv4Addr, net::Ipv6Addr, str::FromStr};

    // A URL and an equivalent builder must produce identical public options.
    #[test]
    fn test_builder_eq_url() {
        const URL: &str = "mysql://iq-controller@localhost/iq_controller";
        let url_opts = super::Opts::from_str(URL).unwrap();
        let builder = super::OptsBuilder::default()
            .user(Some("iq-controller"))
            .ip_or_hostname("localhost")
            .db_name(Some("iq_controller"));
        let builder_opts = Opts::from(builder);

        assert_eq!(url_opts.addr_is_loopback(), builder_opts.addr_is_loopback());
        assert_eq!(url_opts.ip_or_hostname(), builder_opts.ip_or_hostname());
        assert_eq!(url_opts.tcp_port(), builder_opts.tcp_port());
        assert_eq!(url_opts.user(), builder_opts.user());
        assert_eq!(url_opts.pass(), builder_opts.pass());
        assert_eq!(url_opts.db_name(), builder_opts.db_name());
        assert_eq!(url_opts.init(), builder_opts.init());
        assert_eq!(url_opts.setup(), builder_opts.setup());
        assert_eq!(url_opts.tcp_keepalive(), builder_opts.tcp_keepalive());
        assert_eq!(url_opts.tcp_nodelay(), builder_opts.tcp_nodelay());
        assert_eq!(url_opts.pool_opts(), builder_opts.pool_opts());
        assert_eq!(url_opts.conn_ttl(), builder_opts.conn_ttl());
        assert_eq!(url_opts.abs_conn_ttl(), builder_opts.abs_conn_ttl());
        assert_eq!(
            url_opts.abs_conn_ttl_jitter(),
            builder_opts.abs_conn_ttl_jitter()
        );
        assert_eq!(url_opts.stmt_cache_size(), builder_opts.stmt_cache_size());
        assert_eq!(url_opts.ssl_opts(), builder_opts.ssl_opts());
        assert_eq!(url_opts.prefer_socket(), builder_opts.prefer_socket());
        assert_eq!(url_opts.socket(), builder_opts.socket());
        assert_eq!(url_opts.compression(), builder_opts.compression());
        assert_eq!(
            url_opts.hostport_or_url().get_ip_or_hostname(),
            builder_opts.hostport_or_url().get_ip_or_hostname()
        );
        assert_eq!(
            url_opts.hostport_or_url().get_tcp_port(),
            builder_opts.hostport_or_url().get_tcp_port()
        );
    }

    // Credentials, database and a query parameter are all picked up from the URL,
    // and the address keeps the parsed URL verbatim.
    #[test]
    fn should_convert_url_into_opts() {
        let url = "mysql://usr:pw@192.168.1.1:3309/dbname?prefer_socket=true";
        let parsed_url =
            Url::parse("mysql://usr:pw@192.168.1.1:3309/dbname?prefer_socket=true").unwrap();

        let mysql_opts = MysqlOpts {
            user: Some("usr".to_string()),
            pass: Some("pw".to_string()),
            db_name: Some("dbname".to_string()),
            prefer_socket: true,
            ..MysqlOpts::default()
        };
        let host = HostPortOrUrl::Url(parsed_url);

        let opts = Opts::from_url(url).unwrap();

        assert_eq!(opts.inner.mysql_opts, mysql_opts);
        assert_eq!(opts.hostport_or_url(), &host);
    }

    // IPv6 hosts from a URL are reported with their square brackets.
    #[test]
    fn should_convert_ipv6_url_into_opts() {
        let url = "mysql://usr:pw@[::1]:3309/dbname";

        let opts = Opts::from_url(url).unwrap();

        assert_eq!(opts.ip_or_hostname(), "[::1]");
    }

    // TLS flags apply only together with `require_ssl=true` (see URL5).
    #[test]
    fn should_parse_ssl_params() {
        const URL1: &str = "mysql://localhost/foo?require_ssl=false";
        let opts = Opts::from_url(URL1).unwrap();
        assert_eq!(opts.ssl_opts(), None);

        const URL2: &str = "mysql://localhost/foo?require_ssl=true";
        let opts = Opts::from_url(URL2).unwrap();
        assert_eq!(opts.ssl_opts(), Some(&SslOpts::default()));

        const URL3: &str = "mysql://localhost/foo?require_ssl=true&verify_ca=false";
        let opts = Opts::from_url(URL3).unwrap();
        assert_eq!(
            opts.ssl_opts(),
            Some(&SslOpts::default().with_danger_accept_invalid_certs(true))
        );

        const URL4: &str =
            "mysql://localhost/foo?require_ssl=true&verify_ca=false&verify_identity=false&built_in_roots=false";
        let opts = Opts::from_url(URL4).unwrap();
        assert_eq!(
            opts.ssl_opts(),
            Some(
                &SslOpts::default()
                    .with_danger_accept_invalid_certs(true)
                    .with_danger_skip_domain_validation(true)
                    .with_disable_built_in_roots(true)
            )
        );

        const URL5: &str =
            "mysql://localhost/foo?require_ssl=false&verify_ca=false&verify_identity=false";
        let opts = Opts::from_url(URL5).unwrap();
        assert_eq!(opts.ssl_opts(), None);
    }

    // Not a URL at all.
    #[test]
    #[should_panic]
    fn should_panic_on_invalid_url() {
        let opts = "42";
        let _: Opts = Opts::from_str(opts).unwrap();
    }

    // Only the `mysql` scheme is accepted.
    #[test]
    #[should_panic]
    fn should_panic_on_invalid_scheme() {
        let opts = "postgres://localhost";
        let _: Opts = Opts::from_str(opts).unwrap();
    }

    // Unrecognised query parameters are rejected.
    #[test]
    #[should_panic]
    fn should_panic_on_unknown_query_param() {
        let opts = "mysql://localhost/foo?bar=baz";
        let _: Opts = Opts::from_str(opts).unwrap();
    }

    // Every accepted `compression` spelling, plus two rejected values.
    #[test]
    fn should_parse_compression() {
        let err = Opts::from_url("mysql://localhost/foo?compression=").unwrap_err();
        assert_eq!(
            err,
            InvalidParamValue {
                param: "compression".into(),
                value: "".into()
            }
        );

        let err = Opts::from_url("mysql://localhost/foo?compression=a").unwrap_err();
        assert_eq!(
            err,
            InvalidParamValue {
                param: "compression".into(),
                value: "a".into()
            }
        );

        let opts = Opts::from_url("mysql://localhost/foo?compression=fast").unwrap();
        assert_eq!(opts.compression(), Some(crate::Compression::fast()));

        let opts = Opts::from_url("mysql://localhost/foo?compression=on").unwrap();
        assert_eq!(opts.compression(), Some(crate::Compression::default()));

        let opts = Opts::from_url("mysql://localhost/foo?compression=true").unwrap();
        assert_eq!(opts.compression(), Some(crate::Compression::default()));

        let opts = Opts::from_url("mysql://localhost/foo?compression=best").unwrap();
        assert_eq!(opts.compression(), Some(crate::Compression::best()));

        let opts = Opts::from_url("mysql://localhost/foo?compression=0").unwrap();
        assert_eq!(opts.compression(), Some(crate::Compression::new(0)));

        let opts = Opts::from_url("mysql://localhost/foo?compression=9").unwrap();
        assert_eq!(opts.compression(), Some(crate::Compression::new(9)));
    }

    // A URL with no path, or an empty path, yields no database, same as the builder default.
    #[test]
    fn test_builder_eq_url_empty_db() {
        let builder = super::OptsBuilder::default();
        let builder_opts = Opts::from(builder);

        let url: &str = "mysql://iq-controller@localhost";
        let url_opts = super::Opts::from_str(url).unwrap();
        assert_eq!(url_opts.db_name(), builder_opts.db_name());

        let url: &str = "mysql://iq-controller@localhost/";
        let url_opts = super::Opts::from_str(url).unwrap();
        assert_eq!(url_opts.db_name(), builder_opts.db_name());
    }

    // Cloned builders are independent: changing port and resolved IPs on the
    // clone leaves the original untouched.
    #[test]
    fn test_builder_update_port_host_resolved_ips() {
        let builder = super::OptsBuilder::default()
            .ip_or_hostname("foo")
            .tcp_port(33306);

        let resolved = vec![
            IpAddr::V4(Ipv4Addr::new(127, 0, 0, 7)),
            IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)),
        ];
        let builder2 = builder
            .clone()
            .tcp_port(55223)
            .resolved_ips(Some(resolved.clone()));

        let builder_opts = Opts::from(builder);
        assert_eq!(builder_opts.ip_or_hostname(), "foo");
        assert_eq!(builder_opts.tcp_port(), 33306);
        assert_eq!(
            builder_opts.hostport_or_url(),
            &HostPortOrUrl::HostPort {
                host: "foo".to_string(),
                port: 33306,
                resolved_ips: None
            }
        );

        let builder_opts2 = Opts::from(builder2);
        assert_eq!(builder_opts2.ip_or_hostname(), "foo");
        assert_eq!(builder_opts2.tcp_port(), 55223);
        assert_eq!(
            builder_opts2.hostport_or_url(),
            &HostPortOrUrl::HostPort {
                host: "foo".to_string(),
                port: 55223,
                resolved_ips: Some(resolved),
            }
        );
    }
}