use core::str::Utf8Error;
use percent_encoding::percent_decode;
use std::{collections::BTreeMap, error::Error, fmt, str::Chars};
/// Errors returned by [`parse`] when a DSN string is malformed.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseError {
    /// The `driver://` prefix is malformed.
    InvalidDriver,
    /// A `&`-separated query parameter is not of the form `key=value`.
    InvalidParams,
    /// The address of a `file(...)` section is not an absolute path.
    InvalidPath,
    /// The port of a `host:port` address is not a valid `u16`.
    InvalidPort,
    /// NOTE(review): not returned by any code path in this file — confirm whether still needed.
    InvalidProtocol,
    /// The address of a `unix(...)` section is not an absolute path.
    InvalidSocket,
    /// The `protocol(address)` section has nothing between its parentheses.
    MissingAddress,
    /// An address of the form `host:port` has an empty host part.
    MissingHost,
    /// Nothing precedes the `(` of the `protocol(address)` section.
    MissingProtocol,
    /// NOTE(review): not returned by any code path in this file — confirm whether still needed.
    MissingSocket,
    /// Percent-decoding the username or password produced invalid UTF-8.
    Utf8Error(Utf8Error),
}

impl From<Utf8Error> for ParseError {
    /// Wraps a UTF-8 decoding failure so `?` can propagate it as a [`ParseError`].
    fn from(err: Utf8Error) -> Self {
        Self::Utf8Error(err)
    }
}
impl fmt::Display for ParseError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let text = match self {
            Self::InvalidDriver => "invalid driver",
            Self::InvalidParams => "invalid params",
            Self::InvalidPath => "invalid absolute path",
            Self::InvalidPort => "invalid port number",
            Self::InvalidProtocol => "invalid protocol",
            Self::InvalidSocket => "invalid socket",
            Self::MissingAddress => "missing address",
            Self::MissingHost => "missing host",
            Self::MissingProtocol => "missing protocol",
            Self::MissingSocket => "missing unix domain socket",
            // The only data-carrying variant: append the decoder's own message.
            Self::Utf8Error(inner) => return write!(f, "UTF-8 error: {inner}"),
        };
        f.write_str(text)
    }
}

impl Error for ParseError {}
/// Components of a DSN string, as produced by [`parse`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DSN {
    /// Text before `://`, e.g. `mysql`.
    pub driver: String,
    /// Percent-decoded user name; `None` when empty.
    pub username: Option<String>,
    /// Percent-decoded password; `None` when empty.
    pub password: Option<String>,
    /// Text before `(`, e.g. `tcp`, `unix` or `file`.
    pub protocol: String,
    /// Raw text between `(` and `)`.
    pub address: String,
    /// Host part of `address`; set only for protocols other than `unix` and `file`.
    pub host: Option<String>,
    /// Numeric port part of `address`, when one is given.
    pub port: Option<u16>,
    /// Database name following the address; `None` when empty.
    pub database: Option<String>,
    /// Copy of `address` when the protocol is `unix`.
    pub socket: Option<String>,
    /// `key=value` pairs following `?`.
    pub params: BTreeMap<String, String>,
}
pub fn parse(input: &str) -> Result<DSN, ParseError> {
let mut dsn = DSN::default();
let chars = &mut input.chars();
dsn.driver = get_driver(chars)?;
let (user, pass) = get_username_password(chars)?;
if !user.is_empty() {
dsn.username = Some(user);
}
if !pass.is_empty() {
dsn.password = Some(pass);
}
dsn.protocol = get_protocol(chars)?;
dsn.address = get_address(chars)?;
match dsn.protocol.as_ref() {
"unix" => {
if !dsn.address.starts_with('/') {
return Err(ParseError::InvalidSocket);
}
dsn.socket = Some(dsn.address.clone());
}
"file" => {
if !dsn.address.starts_with('/') {
return Err(ParseError::InvalidPath);
}
}
_ => {
let (host, port) = get_host_port(&dsn.address)?;
dsn.host = Some(host);
if !port.is_empty() {
dsn.port = match port.parse::<u16>() {
Ok(n) => Some(n),
Err(_) => return Err(ParseError::InvalidPort),
}
}
}
}
let database = get_database(chars);
if !database.is_empty() {
dsn.database = Some(database);
}
let params = chars.as_str();
if !params.is_empty() {
dsn.params = get_params(chars.as_str())?;
}
Ok(dsn)
}
/// Reads the driver name that precedes `://` and consumes the separator.
///
/// # Errors
/// Returns [`ParseError::InvalidDriver`] when the first `:` is not immediately
/// followed by `//`, when the driver name is empty, or when the input ends
/// without any `://`.
fn get_driver(chars: &mut Chars) -> Result<String, ParseError> {
    let mut driver = String::new();
    while let Some(c) = chars.next() {
        if c == ':' {
            // Both slashes must follow immediately; `mysql:/x` is rejected.
            let has_separator = chars.next() == Some('/') && chars.next() == Some('/');
            if has_separator && !driver.is_empty() {
                return Ok(driver);
            }
            return Err(ParseError::InvalidDriver);
        }
        driver.push(c);
    }
    // No `://` at all: accepting here would treat the whole DSN as the driver name.
    Err(ParseError::InvalidDriver)
}
/// Reads the `user[:password]@` credentials section and percent-decodes both parts.
///
/// If the remaining input contains no `@`, there is no credentials section:
/// nothing is consumed and two empty strings are returned. Otherwise the
/// username ends at the first `:` or `@`. If a `:` ended it, the password runs
/// up to the next `@`.
///
/// # Errors
/// Returns [`ParseError::Utf8Error`] if a decoded part is not valid UTF-8.
fn get_username_password(chars: &mut Chars) -> Result<(String, String), ParseError> {
    // Without this check, a DSN such as `mysql://tcp(host:3306)/db` would have its
    // protocol, address and database swallowed into the username/password.
    // NOTE(review): an `@` that appears only later (e.g. inside a param value)
    // still triggers credential parsing.
    if !chars.as_str().contains('@') {
        return Ok((String::new(), String::new()));
    }

    let mut username = String::new();
    let mut password = String::new();
    // Set when a `:` (rather than `@`) terminated the username.
    let mut has_password = false;
    for c in chars.by_ref() {
        match c {
            '@' => break,
            ':' => {
                has_password = true;
                break;
            }
            _ => username.push(c),
        }
    }
    username = percent_decode(username.as_bytes()).decode_utf8()?.into();

    if has_password {
        for c in chars {
            if c == '@' {
                break;
            }
            password.push(c);
        }
        password = percent_decode(password.as_bytes()).decode_utf8()?.into();
    }
    Ok((username, password))
}
/// Takes everything up to the first `(` as the protocol name and consumes the `(`.
///
/// If no `(` remains, the whole remaining input becomes the protocol and the
/// cursor is left empty.
///
/// # Errors
/// Returns [`ParseError::MissingProtocol`] when `(` is the very next character.
fn get_protocol(chars: &mut Chars) -> Result<String, ParseError> {
    let remaining = chars.as_str();
    let (name, after) = match remaining.split_once('(') {
        Some(("", after)) => {
            *chars = after.chars();
            return Err(ParseError::MissingProtocol);
        }
        Some(pair) => pair,
        None => (remaining, ""),
    };
    *chars = after.chars();
    Ok(name.to_owned())
}
/// Takes everything up to the first `)` as the address and consumes the `)`.
///
/// If no `)` remains, the whole remaining input becomes the address and the
/// cursor is left empty.
///
/// # Errors
/// Returns [`ParseError::MissingAddress`] when `)` is the very next character.
fn get_address(chars: &mut Chars) -> Result<String, ParseError> {
    let remaining = chars.as_str();
    let (address, rest, closed) = match remaining.find(')') {
        Some(end) => (&remaining[..end], &remaining[end + 1..], true),
        None => (remaining, "", false),
    };
    *chars = rest.chars();
    if closed && address.is_empty() {
        return Err(ParseError::MissingAddress);
    }
    Ok(address.to_string())
}
/// Splits `address` at its first `:` into `(host, port)`.
///
/// Without a `:`, the whole address is the host and the port is empty. The
/// port text is returned unparsed; everything after the first `:` belongs to it.
///
/// # Errors
/// Returns [`ParseError::MissingHost`] when the address starts with `:`.
fn get_host_port(address: &str) -> Result<(String, String), ParseError> {
    match address.split_once(':') {
        Some(("", _)) => Err(ParseError::MissingHost),
        Some((host, port)) => Ok((host.to_owned(), port.to_owned())),
        None => Ok((address.to_owned(), String::new())),
    }
}
/// Reads the database name that follows the address.
///
/// Consumes input up to and including the first `?`, or to the end of input.
/// Leading and trailing `/` separators are stripped, but a `/` inside the name
/// is kept, so `/dir/db` yields `dir/db` rather than `dirdb`.
fn get_database(chars: &mut Chars) -> String {
    // `take_while` also consumes the terminating `?`, which leaves only the
    // query string in `chars` for the caller.
    let section: String = chars.by_ref().take_while(|&c| c != '?').collect();
    section.trim_matches('/').to_string()
}
/// Parses a `key=value&key=value` query string into a sorted map.
///
/// Each pair is split at its first `=`, so values may themselves contain `=`
/// (e.g. base64 padding). If a key repeats, the later value overwrites the earlier one.
///
/// # Errors
/// Returns [`ParseError::InvalidParams`] if any `&`-separated pair has no `=`.
/// This includes empty pairs, such as the one produced by a trailing `&`.
fn get_params(params_string: &str) -> Result<BTreeMap<String, String>, ParseError> {
    params_string
        .split('&')
        .map(|pair| {
            pair.split_once('=')
                .map(|(key, value)| (key.to_string(), value.to_string()))
                .ok_or(ParseError::InvalidParams)
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::{parse, ParseError};

    #[test]
    fn test_parse_password() {
        // Exercises raw quote characters in the password. The rest of this input
        // is not in `protocol(address)` form, so only the password is asserted.
        let dsn = parse(r#"mysql://user:pas':"'sword44444@host:port/database"#).unwrap();
        assert_eq!(dsn.password.unwrap(), r#"pas':"'sword44444"#);
    }

    #[test]
    fn test_parse_tcp_dsn_with_params() {
        let dsn = parse("mysql://user:secret@tcp(localhost:3306)/mydb?charset=utf8&parseTime=true")
            .unwrap();
        assert_eq!(dsn.driver, "mysql");
        assert_eq!(dsn.username.as_deref(), Some("user"));
        assert_eq!(dsn.password.as_deref(), Some("secret"));
        assert_eq!(dsn.protocol, "tcp");
        assert_eq!(dsn.address, "localhost:3306");
        assert_eq!(dsn.host.as_deref(), Some("localhost"));
        assert_eq!(dsn.port, Some(3306));
        assert_eq!(dsn.database.as_deref(), Some("mydb"));
        assert_eq!(dsn.socket, None);
        assert_eq!(dsn.params.get("charset").map(String::as_str), Some("utf8"));
        assert_eq!(dsn.params.get("parseTime").map(String::as_str), Some("true"));
    }

    #[test]
    fn test_parse_unix_socket() {
        let dsn = parse("mysql://user:pass@unix(/tmp/mysql.sock)/db").unwrap();
        assert_eq!(dsn.protocol, "unix");
        assert_eq!(dsn.socket.as_deref(), Some("/tmp/mysql.sock"));
        assert_eq!(dsn.host, None);
        assert_eq!(dsn.port, None);
        assert_eq!(dsn.database.as_deref(), Some("db"));
    }

    #[test]
    fn test_parse_percent_encoded_password() {
        let dsn = parse("postgres://user:p%40ss@tcp(db:5432)/app").unwrap();
        assert_eq!(dsn.password.as_deref(), Some("p@ss"));
        assert_eq!(dsn.port, Some(5432));
    }

    #[test]
    fn test_parse_errors() {
        assert!(matches!(
            parse("mysql://u:p@tcp(localhost:99999)/db"),
            Err(ParseError::InvalidPort)
        ));
        assert!(matches!(parse("mysql:/x"), Err(ParseError::InvalidDriver)));
        assert!(matches!(
            parse("mysql://u:p@unix(tmp.sock)/db"),
            Err(ParseError::InvalidSocket)
        ));
    }
}