use percent_encoding::percent_decode;
use std::{collections::HashMap, error::Error, fmt, str::Chars};
#[derive(Debug)]
pub enum ParseError {
InvalidPort,
InvalidDriver,
InvalidSocket,
MissingUsername,
MissingProtocol,
MissingAddress,
MissingHost,
MissingSocket,
}
impl fmt::Display for ParseError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
self.description().fmt(fmt)
}
}
impl Error for ParseError {
fn description(&self) -> &str {
match *self {
ParseError::InvalidPort => "invalid port number",
ParseError::InvalidDriver => "invalid driver",
ParseError::InvalidSocket => "invalid socket",
ParseError::MissingUsername => "missing username",
ParseError::MissingProtocol => "missing protocol",
ParseError::MissingAddress => "missing address",
ParseError::MissingHost => "missing host",
ParseError::MissingSocket => "missing unix domain socket",
}
}
}
#[derive(Debug, Default)]
pub struct DSN {
pub driver: String,
pub username: String,
pub password: Option<String>,
pub protocol: String,
pub address: String,
pub host: String,
pub port: Option<u16>,
pub database: Option<String>,
pub socket: Option<String>,
pub params: HashMap<String, String>,
}
pub fn parse(input: &str) -> Result<DSN, ParseError> {
let mut dsn = DSN::default();
let chars = &mut input.chars();
dsn.driver = get_driver(chars)?;
let (user, pass) = get_username_password(chars)?;
dsn.username = user;
if pass.len() > 0 {
dsn.password = Some(pass);
} else {
dsn.password = None;
}
dsn.protocol = get_protocol(chars)?;
dsn.address = get_address(chars)?;
if dsn.protocol == "unix" {
if !dsn.address.starts_with("/") {
return Err(ParseError::InvalidSocket);
}
dsn.socket = Some(dsn.address.clone())
} else {
let (host, port) = get_host_port(dsn.address.clone())?;
dsn.host = host;
if port.len() > 0 {
dsn.port = match port.parse::<u16>() {
Ok(n) => Some(n),
Err(_) => return Err(ParseError::InvalidPort),
}
}
}
Ok(dsn)
}
fn get_driver(chars: &mut Chars) -> Result<String, ParseError> {
let mut driver = String::new();
while let Some(c) = chars.next() {
if c == ':' {
if chars.next() == Some('/') && chars.next() == Some('/') {
break;
}
return Err(ParseError::InvalidDriver);
}
driver.push(c);
}
Ok(driver)
}
fn get_username_password(chars: &mut Chars) -> Result<(String, String), ParseError> {
let mut username = String::new();
let mut password = String::new();
while let Some(c) = chars.next() {
match c {
'@' => {
if username.len() == 0 {
return Err(ParseError::MissingUsername);
}
break;
}
':' => {
if username.len() == 0 {
return Err(ParseError::MissingUsername);
}
break;
}
_ => username.push(c),
}
}
while let Some(c) = chars.next() {
match c {
'@' => break,
_ => password.push(c),
}
}
Ok((
percent_decode(username.as_bytes())
.decode_utf8()
.unwrap()
.into(),
percent_decode(password.as_bytes())
.decode_utf8()
.unwrap()
.into(),
))
}
fn get_protocol(chars: &mut Chars) -> Result<String, ParseError> {
let mut protocol = String::new();
while let Some(c) = chars.next() {
match c {
'(' => {
if protocol.len() == 0 {
return Err(ParseError::MissingProtocol);
}
break;
}
_ => protocol.push(c),
}
}
Ok(protocol)
}
fn get_address(chars: &mut Chars) -> Result<String, ParseError> {
let mut address = String::new();
while let Some(c) = chars.next() {
match c {
')' => {
if address.len() == 0 {
return Err(ParseError::MissingAddress);
}
break;
}
_ => address.push(c),
}
}
Ok(address)
}
fn get_host_port(address: String) -> Result<(String, String), ParseError> {
let mut host = String::new();
let mut chars = address.chars();
while let Some(c) = chars.next() {
match c {
':' => {
if host.len() == 0 {
return Err(ParseError::MissingHost);
}
break;
}
_ => host.push(c),
}
}
let port = chars.as_str();
Ok((host, port.into()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse() {
let dsn = parse(r#"mysql://user:o%3Ao@tcp(localhost:3306)/database"#).unwrap();
println!("{:?}", dsn);
assert_eq!(dsn.driver, "mysql");
assert_eq!(dsn.username, "user");
assert_eq!(dsn.password.unwrap(), "o:o");
assert_eq!(dsn.protocol, "tcp");
assert_eq!(dsn.address, "localhost:3306");
assert_eq!(dsn.host, "localhost");
assert_eq!(dsn.port.unwrap(), 3306);
assert_eq!(dsn.database, None);
assert_eq!(dsn.socket, None);
}
#[test]
fn test_parse_password() {
let dsn = parse(r#"mysql://user:pas':"'sword44444@host:port/database"#).unwrap();
assert_eq!(dsn.password.unwrap(), r#"pas':"'sword44444"#);
}
}