use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
/// TLS settings carried by a MySQL client configuration.
///
/// Every field is public and optional; the `Default` value has no CA,
/// no client certificate, no server-name override, and `danger_skip_verify`
/// set to `false`.
#[derive(Clone, Debug, Default)]
pub struct TlsConfig {
    /// Path to a CA certificate file to trust, if any.
    pub ca_cert_path: Option<PathBuf>,
    /// Path to the client certificate; only meaningful together with `client_key_path`.
    pub client_cert_path: Option<PathBuf>,
    /// Path to the client private key; only meaningful together with `client_cert_path`.
    pub client_key_path: Option<PathBuf>,
    /// NOTE(review): presumably disables server certificate verification in the
    /// TLS layer (consumer not visible here) — dangerous outside of testing.
    pub danger_skip_verify: bool,
    /// Optional override for the server name, presumably used for certificate
    /// name checks — TODO confirm against the TLS connector.
    pub server_name: Option<String>,
}
impl TlsConfig {
    /// Creates an empty TLS configuration (same as `TlsConfig::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `self` with the CA certificate path set.
    pub fn ca_cert(self, path: impl Into<PathBuf>) -> Self {
        Self {
            ca_cert_path: Some(path.into()),
            ..self
        }
    }

    /// Returns `self` with the client certificate path set.
    ///
    /// A client key must also be set for [`TlsConfig::has_client_cert`] to report `true`.
    pub fn client_cert(self, path: impl Into<PathBuf>) -> Self {
        Self {
            client_cert_path: Some(path.into()),
            ..self
        }
    }

    /// Returns `self` with the client private-key path set.
    pub fn client_key(self, path: impl Into<PathBuf>) -> Self {
        Self {
            client_key_path: Some(path.into()),
            ..self
        }
    }

    /// Returns `self` with `danger_skip_verify` set to `skip`.
    pub fn skip_verify(self, skip: bool) -> Self {
        Self {
            danger_skip_verify: skip,
            ..self
        }
    }

    /// Returns `self` with the server-name override set.
    pub fn server_name(self, name: impl Into<String>) -> Self {
        Self {
            server_name: Some(name.into()),
            ..self
        }
    }

    /// Reports whether a full client identity is configured, i.e. BOTH the
    /// certificate path and the key path are set. Either one alone yields `false`.
    pub fn has_client_cert(&self) -> bool {
        matches!(
            (&self.client_cert_path, &self.client_key_path),
            (Some(_), Some(_))
        )
    }
}
/// Client TLS policy. The variant names follow MySQL's `--ssl-mode` options;
/// how each mode is enforced lives in the connection code (not shown here).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Never attempt TLS. This is the default.
    #[default]
    Disable,
    /// Attempt TLS, but it is not mandatory (see `is_required`).
    Preferred,
    /// TLS is mandatory.
    Required,
    /// TLS is mandatory; presumably also verifies the server certificate
    /// against a CA — TODO confirm in the TLS connector.
    VerifyCa,
    /// TLS is mandatory; presumably also verifies the server identity
    /// (host name) — TODO confirm in the TLS connector.
    VerifyIdentity,
}
impl SslMode {
    /// Returns `true` for every mode except [`SslMode::Disable`], i.e. whenever
    /// the client should offer TLS to the server.
    ///
    /// The match is deliberately exhaustive so a new variant forces a decision here.
    pub const fn should_try_ssl(self) -> bool {
        match self {
            SslMode::Disable => false,
            SslMode::Preferred
            | SslMode::Required
            | SslMode::VerifyCa
            | SslMode::VerifyIdentity => true,
        }
    }

    /// Returns `true` when TLS is mandatory: `Required`, `VerifyCa` and
    /// `VerifyIdentity`. `Disable` and `Preferred` return `false`.
    pub const fn is_required(self) -> bool {
        match self {
            SslMode::Disable | SslMode::Preferred => false,
            SslMode::Required | SslMode::VerifyCa | SslMode::VerifyIdentity => true,
        }
    }
}
/// Connection settings for a MySQL client.
///
/// `Debug` is implemented by hand so that the password is never printed:
/// logging a config with `{:?}` shows only whether a password is set.
#[derive(Clone)]
pub struct MySqlConfig {
    /// Server host name or IP address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// User name to authenticate as.
    pub user: String,
    /// Password, if any. Redacted from `Debug` output.
    pub password: Option<String>,
    /// Database to select at connect time; when set, `capability_flags`
    /// requests `CLIENT_CONNECT_WITH_DB`.
    pub database: Option<String>,
    /// Character-set / collation id (default `UTF8MB4_0900_AI_CI`).
    pub charset: u8,
    /// Connection timeout; enforced by the connecting code, not shown here.
    pub connect_timeout: Duration,
    /// TLS policy; any mode other than `Disable` sets `CLIENT_SSL`.
    pub ssl_mode: SslMode,
    /// TLS certificate and verification settings.
    pub tls_config: TlsConfig,
    /// Requests protocol compression (`CLIENT_COMPRESS`).
    pub compression: bool,
    /// Connection attributes; a non-empty map sets `CLIENT_CONNECT_ATTRS`.
    pub attributes: HashMap<String, String>,
    /// Sets `CLIENT_LOCAL_FILES` when enabled.
    pub local_infile: bool,
    /// Maximum packet size (default `64 * 1024 * 1024`).
    pub max_packet_size: u32,
}

impl std::fmt::Debug for MySqlConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MySqlConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Never print the secret itself; only whether one is present.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .field("charset", &self.charset)
            .field("connect_timeout", &self.connect_timeout)
            .field("ssl_mode", &self.ssl_mode)
            .field("tls_config", &self.tls_config)
            .field("compression", &self.compression)
            .field("attributes", &self.attributes)
            .field("local_infile", &self.local_infile)
            .field("max_packet_size", &self.max_packet_size)
            .finish()
    }
}
impl Default for MySqlConfig {
    /// Default settings:
    /// - `localhost:3306`, an empty user name, no password and no database;
    /// - charset `UTF8MB4_0900_AI_CI` and a 30-second connect timeout;
    /// - `SslMode::default()` and an empty `TlsConfig`;
    /// - compression off, no attributes and `local_infile` off;
    /// - a maximum packet size of `64 * 1024 * 1024`.
    fn default() -> Self {
        // Written as a product so the 64 Mi magnitude is obvious at a glance.
        const DEFAULT_MAX_PACKET_SIZE: u32 = 64 * 1024 * 1024;

        Self {
            host: String::from("localhost"),
            port: 3306,
            user: String::new(),
            password: None,
            database: None,
            charset: crate::protocol::charset::UTF8MB4_0900_AI_CI,
            connect_timeout: Duration::from_secs(30),
            ssl_mode: SslMode::default(),
            tls_config: TlsConfig::default(),
            compression: false,
            attributes: HashMap::new(),
            local_infile: false,
            max_packet_size: DEFAULT_MAX_PACKET_SIZE,
        }
    }
}
impl MySqlConfig {
    /// Creates a configuration populated with the `Default` values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the server host name or IP address.
    ///
    /// IPv6 literals may be given bare (`::1`) or bracketed (`[::1]`);
    /// `socket_addr` produces a valid `host:port` string in both cases.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.host = host.into();
        self
    }

    /// Sets the server TCP port.
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }

    /// Sets the user name.
    pub fn user(mut self, user: impl Into<String>) -> Self {
        self.user = user.into();
        self
    }

    /// Sets the password, replacing any previous one.
    pub fn password(mut self, password: impl Into<String>) -> Self {
        // Plain assignment, matching the other optional setters (`database`, ...),
        // instead of `Option::replace` whose returned old value was discarded.
        self.password = Some(password.into());
        self
    }

    /// Returns the password, or `""` when none is set.
    pub(crate) fn password_str(&self) -> &str {
        self.password.as_deref().unwrap_or_default()
    }

    /// Returns an owned copy of the password, or an empty `String` when none is set.
    pub(crate) fn password_owned(&self) -> String {
        self.password.clone().unwrap_or_default()
    }

    /// Sets the database to select at connect time.
    pub fn database(mut self, database: impl Into<String>) -> Self {
        self.database = Some(database.into());
        self
    }

    /// Sets the character-set / collation id.
    pub fn charset(mut self, charset: u8) -> Self {
        self.charset = charset;
        self
    }

    /// Sets the connection timeout.
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets the TLS policy.
    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
        self.ssl_mode = mode;
        self
    }

    /// Replaces the whole TLS configuration.
    pub fn tls_config(mut self, config: TlsConfig) -> Self {
        self.tls_config = config;
        self
    }

    /// Sets only `tls_config.ca_cert_path`, leaving the other TLS fields untouched.
    pub fn ca_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.tls_config.ca_cert_path = Some(path.into());
        self
    }

    /// Sets the client certificate and key paths together, leaving the other
    /// TLS fields untouched.
    pub fn client_cert(
        mut self,
        cert_path: impl Into<PathBuf>,
        key_path: impl Into<PathBuf>,
    ) -> Self {
        self.tls_config.client_cert_path = Some(cert_path.into());
        self.tls_config.client_key_path = Some(key_path.into());
        self
    }

    /// Enables or disables protocol compression.
    pub fn compression(mut self, enabled: bool) -> Self {
        self.compression = enabled;
        self
    }

    /// Adds (or overwrites) one connection attribute.
    pub fn attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.attributes.insert(key.into(), value.into());
        self
    }

    /// Enables or disables the `CLIENT_LOCAL_FILES` capability.
    pub fn local_infile(mut self, enabled: bool) -> Self {
        self.local_infile = enabled;
        self
    }

    /// Sets the maximum packet size.
    pub fn max_packet_size(mut self, size: u32) -> Self {
        self.max_packet_size = size;
        self
    }

    /// Returns the `host:port` string for this configuration.
    ///
    /// A bare IPv6 literal host is wrapped in brackets (`::1` gives `[::1]:3306`).
    /// The unbracketed form `::1:3306` does not parse as a `std::net::SocketAddr`,
    /// and it is ambiguous because it is itself a valid IPv6 address.
    /// Host names, IPv4 literals, and already-bracketed or zone-qualified hosts are
    /// passed through unchanged, exactly as before.
    pub fn socket_addr(&self) -> String {
        if self.host.parse::<std::net::Ipv6Addr>().is_ok() {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }

    /// Computes the client capability flags to advertise.
    ///
    /// Starts from `DEFAULT_CLIENT_FLAGS` and adds:
    /// - `CLIENT_CONNECT_WITH_DB` when a database is set;
    /// - `CLIENT_SSL` when `ssl_mode.should_try_ssl()`;
    /// - `CLIENT_COMPRESS` when compression is enabled;
    /// - `CLIENT_LOCAL_FILES` when `local_infile` is enabled;
    /// - `CLIENT_CONNECT_ATTRS` when at least one attribute is set.
    pub fn capability_flags(&self) -> u32 {
        use crate::protocol::capabilities::{
            CLIENT_COMPRESS, CLIENT_CONNECT_ATTRS, CLIENT_CONNECT_WITH_DB, CLIENT_LOCAL_FILES,
            CLIENT_SSL, DEFAULT_CLIENT_FLAGS,
        };
        let mut flags = DEFAULT_CLIENT_FLAGS;
        if self.database.is_some() {
            flags |= CLIENT_CONNECT_WITH_DB;
        }
        if self.ssl_mode.should_try_ssl() {
            flags |= CLIENT_SSL;
        }
        if self.compression {
            flags |= CLIENT_COMPRESS;
        }
        if self.local_infile {
            flags |= CLIENT_LOCAL_FILES;
        }
        if !self.attributes.is_empty() {
            flags |= CLIENT_CONNECT_ATTRS;
        }
        flags
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_builder() {
        let cfg = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .user("myuser")
            .password("test")
            .database("testdb")
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Required)
            .compression(true)
            .attribute("program_name", "myapp");

        assert_eq!(cfg.host, "db.example.com");
        assert_eq!(cfg.port, 3307);
        assert_eq!(cfg.user, "myuser");
        assert_eq!(cfg.password.as_deref(), Some("test"));
        assert_eq!(cfg.database.as_deref(), Some("testdb"));
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.ssl_mode, SslMode::Required);
        assert!(cfg.compression);
        assert_eq!(
            cfg.attributes.get("program_name").map(String::as_str),
            Some("myapp")
        );
    }

    #[test]
    fn test_socket_addr() {
        let addr = MySqlConfig::new()
            .host("db.example.com")
            .port(3307)
            .socket_addr();
        assert_eq!(addr, "db.example.com:3307");
    }

    #[test]
    fn test_ssl_mode_properties() {
        // (mode, expected should_try_ssl, expected is_required)
        let table = [
            (SslMode::Disable, false, false),
            (SslMode::Preferred, true, false),
            (SslMode::Required, true, true),
            (SslMode::VerifyCa, true, true),
            (SslMode::VerifyIdentity, true, true),
        ];
        for (mode, tries_ssl, required) in table {
            assert_eq!(mode.should_try_ssl(), tries_ssl, "should_try_ssl({:?})", mode);
            assert_eq!(mode.is_required(), required, "is_required({:?})", mode);
        }
    }

    #[test]
    fn test_capability_flags() {
        use crate::protocol::capabilities::*;
        let flags = MySqlConfig::new()
            .database("test")
            .compression(true)
            .capability_flags();
        for expected_bit in [
            CLIENT_CONNECT_WITH_DB,
            CLIENT_COMPRESS,
            CLIENT_PROTOCOL_41,
            CLIENT_SECURE_CONNECTION,
        ] {
            assert_ne!(flags & expected_bit, 0);
        }
    }

    #[test]
    fn test_default_config() {
        let cfg = MySqlConfig::default();
        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 3306);
        assert_eq!(cfg.ssl_mode, SslMode::Disable);
        assert!(!cfg.compression && !cfg.local_infile);
    }

    #[test]
    fn test_tls_config_builder() {
        let tls = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem")
            .client_key("/path/to/client-key.pem")
            .server_name("db.example.com");

        let expected_paths = [
            (&tls.ca_cert_path, "/path/to/ca.pem"),
            (&tls.client_cert_path, "/path/to/client.pem"),
            (&tls.client_key_path, "/path/to/client-key.pem"),
        ];
        for (actual, want) in expected_paths {
            assert_eq!(*actual, Some(PathBuf::from(want)));
        }
        assert_eq!(tls.server_name.as_deref(), Some("db.example.com"));
        assert!(!tls.danger_skip_verify);
        assert!(tls.has_client_cert());
    }

    #[test]
    fn test_tls_config_skip_verify() {
        assert!(TlsConfig::new().skip_verify(true).danger_skip_verify);
    }

    #[test]
    fn test_mysql_config_with_tls() {
        let cfg = MySqlConfig::new()
            .host("db.example.com")
            .ssl_mode(SslMode::VerifyCa)
            .ca_cert("/etc/ssl/certs/ca.pem")
            .client_cert(
                "/home/user/.mysql/client-cert.pem",
                "/home/user/.mysql/client-key.pem",
            );
        let tls = &cfg.tls_config;

        assert_eq!(cfg.ssl_mode, SslMode::VerifyCa);
        assert_eq!(tls.ca_cert_path, Some(PathBuf::from("/etc/ssl/certs/ca.pem")));
        assert!(tls.has_client_cert());
    }

    #[test]
    fn test_tls_config_no_client_cert() {
        let ca_only = TlsConfig::new().ca_cert("/path/to/ca.pem");
        assert!(!ca_only.has_client_cert());

        let cert_without_key = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/client.pem");
        assert!(!cert_without_key.has_client_cert());
    }
}