use std::fmt;
/// URL schemes understood by this module.
///
/// [`parse_http_url`] accepts the HTTP variants and [`parse_ws_url`] the WebSocket ones.
/// `Hash` is derived alongside `Eq` so a scheme can key a `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Scheme {
    /// `http://`: plain-text HTTP.
    Http,
    /// `https://`: HTTP over TLS.
    Https,
    /// `ws://`: plain-text WebSocket.
    WebSocket,
    /// `wss://`: WebSocket over TLS.
    WebSocketSecure,
}
/// A URL split into the pieces needed to open a connection.
///
/// Values of this type are produced by [`parse_http_url`] and [`parse_ws_url`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParsedUrl {
    /// Scheme the URL was written with.
    pub scheme: Scheme,
    /// Host name or IP literal; the parsers strip the `[` `]` around an IPv6 literal.
    pub host: String,
    /// Explicit port, or the scheme's default when the URL did not give one.
    pub port: u16,
    /// Path component as produced by the parsers; it starts with `/` and defaults to `/`.
    pub path: String,
}

impl ParsedUrl {
    /// Returns `true` when the scheme is one of the TLS-secured ones (`https`, `wss`).
    #[inline]
    pub fn uses_tls(&self) -> bool {
        match self.scheme {
            Scheme::Https | Scheme::WebSocketSecure => true,
            Scheme::Http | Scheme::WebSocket => false,
        }
    }

    /// Formats `host:port`.
    ///
    /// A host containing `:` (an IPv6 literal) is wrapped in square brackets so the
    /// port separator stays unambiguous, e.g. `[2001:db8::1]:8443`.
    #[inline]
    pub fn authority(&self) -> String {
        let host = &self.host;
        let port = self.port;
        if !host.contains(':') {
            return format!("{}:{}", host, port);
        }
        format!("[{}]:{}", host, port)
    }
}
/// Reasons [`parse_http_url`] and [`parse_ws_url`] can reject an input.
///
/// The enum carries no data, so it is also `Copy` and `Hash` (usable as a map key,
/// e.g. for counting failures).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UrlError {
    /// The input has no recognisable `scheme://` prefix.
    MissingScheme,
    /// The scheme is not one the called parser accepts.
    UnsupportedScheme,
    /// The host is empty, or an IPv6 literal's `[` is never closed.
    MissingHost,
    /// The port is not a valid decimal `u16`, or something other than `:port`
    /// follows a bracketed host.
    InvalidPort,
}

impl fmt::Display for UrlError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::MissingScheme => "missing URL scheme",
            Self::UnsupportedScheme => "unsupported URL scheme",
            Self::MissingHost => "missing URL host",
            Self::InvalidPort => "invalid URL port",
        };
        f.write_str(message)
    }
}

/// `UrlError` wraps no underlying cause, so the default `source()` (returning `None`)
/// is kept.
impl std::error::Error for UrlError {}
pub fn parse_http_url(input: &str) -> Result<ParsedUrl, UrlError> {
parse_url(
input,
&[
(Scheme::Http, "http://", 80),
(Scheme::Https, "https://", 443),
],
)
}
pub fn parse_ws_url(input: &str) -> Result<ParsedUrl, UrlError> {
parse_url(
input,
&[
(Scheme::WebSocket, "ws://", 80),
(Scheme::WebSocketSecure, "wss://", 443),
],
)
}
/// Parses `input` as an absolute URL whose scheme must be one of `schemes`.
///
/// Each entry of `schemes` is `(scheme, prefix, default_port)`, where `prefix` is the
/// scheme name followed by `"://"` (for example `"https://"`). Scheme names are
/// matched ASCII case-insensitively (RFC 3986 §3.1).
///
/// The authority ends at the first `/`, `?` or `#` (RFC 3986 §3.2). The returned
/// `path` is the path (defaulting to `/`) followed by any query string. A fragment is
/// dropped because it is never sent to a server.
///
/// # Errors
///
/// * [`UrlError::MissingScheme`]: no `scheme://` prefix, or the text before `://` is
///   not a syntactically valid scheme (so a `://` inside a query is not mistaken for one).
/// * [`UrlError::UnsupportedScheme`]: the scheme is valid but not listed in `schemes`.
/// * [`UrlError::MissingHost`]: the host is empty (including `:port` with no host), or
///   an IPv6 literal's `[` is never closed or encloses nothing.
/// * [`UrlError::InvalidPort`]: the port is empty, contains anything but ASCII digits,
///   or exceeds 65535, or a bracketed host is followed by something other than `:port`.
fn parse_url(
    input: &str,
    schemes: &[(Scheme, &'static str, u16)],
) -> Result<ParsedUrl, UrlError> {
    let (scheme, default_port, remainder) = split_scheme(input, schemes)?;
    // The authority runs up to the first path, query or fragment delimiter.
    let authority_end = remainder
        .find(|c: char| matches!(c, '/' | '?' | '#'))
        .unwrap_or(remainder.len());
    let (authority, rest) = remainder.split_at(authority_end);
    let (host, port) = parse_authority(authority, default_port)?;
    Ok(ParsedUrl {
        scheme,
        host,
        port,
        path: request_target(rest),
    })
}

/// Splits the `scheme://` prefix off `input` and resolves it against `schemes`,
/// returning the matched scheme, its default port and the text after `://`.
fn split_scheme<'a>(
    input: &'a str,
    schemes: &[(Scheme, &'static str, u16)],
) -> Result<(Scheme, u16, &'a str), UrlError> {
    let Some((scheme_text, remainder)) = input.split_once("://") else {
        return Err(UrlError::MissingScheme);
    };
    // RFC 3986: scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ). Anything else means
    // the "://" found belongs to a later component, e.g. a URL inside a query value.
    let is_scheme = scheme_text.starts_with(|c: char| c.is_ascii_alphabetic())
        && scheme_text
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'));
    if !is_scheme {
        return Err(UrlError::MissingScheme);
    }
    schemes
        .iter()
        .copied()
        .find(|(_, prefix, _)| {
            matches!(prefix.strip_suffix("://"), Some(name) if name.eq_ignore_ascii_case(scheme_text))
        })
        .map(|(scheme, _, default_port)| (scheme, default_port, remainder))
        .ok_or(UrlError::UnsupportedScheme)
}

/// Splits an authority (`host`, `host:port`, `[v6]` or `[v6]:port`) into host and port,
/// using `default_port` when none is given.
fn parse_authority(authority: &str, default_port: u16) -> Result<(String, u16), UrlError> {
    if let Some(bracketed) = authority.strip_prefix('[') {
        // IPv6 literal: `[addr]`, optionally followed by `:port`.
        let Some((host, after_bracket)) = bracketed.split_once(']') else {
            return Err(UrlError::MissingHost);
        };
        if host.is_empty() {
            return Err(UrlError::MissingHost);
        }
        let port = if after_bracket.is_empty() {
            default_port
        } else {
            let port_text = after_bracket
                .strip_prefix(':')
                .ok_or(UrlError::InvalidPort)?;
            parse_port(port_text)?
        };
        return Ok((host.to_string(), port));
    }

    // Exactly one ':' separates host from port. Text with several colons is kept whole
    // and treated as an unbracketed IPv6 literal on the default port.
    let (host, port_text) = match authority.split_once(':') {
        Some((host, port_text)) if !port_text.contains(':') => (host, Some(port_text)),
        _ => (authority, None),
    };
    // Checked before the port so `:8080` reports the missing host, not a default host.
    if host.is_empty() {
        return Err(UrlError::MissingHost);
    }
    let port = match port_text {
        Some(text) => parse_port(text)?,
        None => default_port,
    };
    Ok((host.to_string(), port))
}

/// Parses a decimal port. Anything but ASCII digits is rejected up front because
/// `u16::from_str` would otherwise accept a leading `+`.
fn parse_port(text: &str) -> Result<u16, UrlError> {
    if text.is_empty() || !text.bytes().all(|b| b.is_ascii_digit()) {
        return Err(UrlError::InvalidPort);
    }
    text.parse().map_err(|_| UrlError::InvalidPort)
}

/// Builds the request target from the text after the authority (empty, or starting
/// with `/`, `?` or `#`): the fragment is removed and a bare query gets the implicit `/`.
fn request_target(rest: &str) -> String {
    let without_fragment = match rest.split_once('#') {
        Some((before, _)) => before,
        None => rest,
    };
    if without_fragment.starts_with('/') {
        without_fragment.to_string()
    } else {
        format!("/{without_fragment}")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the `ParsedUrl` a successful parse should produce.
    fn expected(scheme: Scheme, host: &str, port: u16, path: &str) -> ParsedUrl {
        ParsedUrl {
            scheme,
            host: host.to_string(),
            port,
            path: path.to_string(),
        }
    }

    #[test]
    fn parse_http_and_https_urls() {
        let http = parse_http_url("http://example.com/path").unwrap();
        assert_eq!(http, expected(Scheme::Http, "example.com", 80, "/path"));
        assert!(!http.uses_tls());

        let https = parse_http_url("https://example.com").unwrap();
        assert_eq!(https, expected(Scheme::Https, "example.com", 443, "/"));
        assert!(https.uses_tls());
    }

    #[test]
    fn parse_ws_and_wss_urls() {
        let ws = parse_ws_url("ws://example.com/slots").unwrap();
        assert_eq!(ws, expected(Scheme::WebSocket, "example.com", 80, "/slots"));
        assert!(!ws.uses_tls());

        let wss = parse_ws_url("wss://example.com").unwrap();
        assert_eq!(wss, expected(Scheme::WebSocketSecure, "example.com", 443, "/"));
        assert!(wss.uses_tls());
    }

    #[test]
    fn parse_bracketed_ipv6_url() {
        let parsed = parse_http_url("https://[2001:db8::1]:8443/rpc").unwrap();
        assert_eq!(parsed, expected(Scheme::Https, "2001:db8::1", 8443, "/rpc"));
        assert_eq!(parsed.authority(), "[2001:db8::1]:8443");
    }

    #[test]
    fn reject_unsupported_scheme() {
        let err = parse_http_url("ftp://example.com").unwrap_err();
        assert_eq!(err, UrlError::UnsupportedScheme);
    }
}