use anyhow::{Context, Result, bail};
use base64::{Engine, engine::general_purpose::STANDARD};
use percent_encoding::percent_decode_str;
use std::{fmt, str::FromStr};
use url::Url;
/// Protocol spoken to an upstream proxy.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ProxyScheme {
    /// Plain HTTP proxy; `http://` and `https://` URLs are parsed into this variant.
    Http,
    /// SOCKS5 proxy; `socks5://` and the `socks://` alias are parsed into this variant.
    Socks5,
}

impl ProxyScheme {
    /// Canonical lowercase scheme name used when rendering proxy URLs.
    pub fn as_str(self) -> &'static str {
        match self {
            ProxyScheme::Socks5 => "socks5",
            ProxyScheme::Http => "http",
        }
    }
}
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct UpstreamProxy {
scheme: ProxyScheme,
host: String,
port: u16,
username: Option<String>,
password: Option<String>,
}
impl UpstreamProxy {
    /// Parses `raw` into an [`UpstreamProxy`]; shorthand for `raw.parse()`.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the [`FromStr`] implementation.
    pub fn parse(raw: &str) -> Result<Self> {
        raw.parse()
    }

    /// Protocol spoken to the proxy.
    pub fn scheme(&self) -> ProxyScheme {
        self.scheme
    }

    /// Proxy host name or IP literal, as stored at parse time.
    pub fn host(&self) -> &str {
        &self.host
    }

    /// Proxy TCP port.
    pub fn port(&self) -> u16 {
        self.port
    }

    /// Percent-decoded user name, if a non-empty one was supplied.
    pub fn username(&self) -> Option<&str> {
        self.username.as_deref()
    }

    /// Percent-decoded password, if a non-empty one was supplied.
    pub fn password(&self) -> Option<&str> {
        self.password.as_deref()
    }

    /// Whether a user name or a password was supplied.
    pub fn has_auth(&self) -> bool {
        self.username.is_some() || self.password.is_some()
    }

    /// Whether this proxy is anything other than an unauthenticated HTTP proxy.
    ///
    /// True for every non-HTTP scheme and for HTTP proxies carrying credentials.
    /// [`Self::command_line_url`] never includes credentials.
    /// NOTE(review): presumably the consumer of that URL only understands plain HTTP
    /// proxies, so these cases are routed through a local bridge — confirm with callers.
    pub fn needs_bridge(&self) -> bool {
        self.scheme != ProxyScheme::Http || self.has_auth()
    }

    /// `scheme://host:port` with credentials omitted; an IPv6 host is bracketed if needed.
    pub fn command_line_url(&self) -> String {
        format!(
            "{}://{}:{}",
            self.scheme.as_str(),
            format_host_for_url(&self.host),
            self.port
        )
    }

    /// `host:port` pair; an IPv6 host is bracketed the same way as in
    /// [`Self::command_line_url`].
    pub fn authority(&self) -> String {
        format!("{}:{}", format_host_for_url(&self.host), self.port)
    }

    /// `Basic <base64(user:password)>` credential string (HTTP Basic scheme) for the proxy.
    ///
    /// Returns `None` only when no credentials were supplied at all. A missing user name or
    /// password is encoded as empty, so a password-only URL such as `http://:secret@host`
    /// still yields a credential instead of silently dropping the password (which
    /// [`Self::has_auth`] and [`Self::needs_bridge`] already count as authenticated).
    pub fn basic_proxy_authorization(&self) -> Option<String> {
        if !self.has_auth() {
            return None;
        }
        let username = self.username.as_deref().unwrap_or_default();
        let password = self.password.as_deref().unwrap_or_default();
        let token = STANDARD.encode(format!("{username}:{password}"));
        Some(format!("Basic {token}"))
    }
}
impl FromStr for UpstreamProxy {
    type Err = anyhow::Error;

    /// Parses a proxy URL such as `http://host:port`, `socks5://user:pass@host`, or a bare
    /// `host:port` (treated as HTTP).
    ///
    /// Processing steps:
    /// - Surrounding whitespace and double quotes are stripped first.
    /// - Schemes: `http`/`https` map to [`ProxyScheme::Http`]; `socks`/`socks5` map to
    ///   [`ProxyScheme::Socks5`].
    /// - A missing port falls back to the URL scheme's well-known default, then to
    ///   `default_port`.
    /// - Credentials are percent-decoded; empty ones become `None`.
    ///
    /// # Errors
    ///
    /// Fails when the input is empty, is not a valid URL, uses an unsupported scheme, or
    /// has no host. The credentials part of the input is masked in error messages.
    fn from_str(raw: &str) -> Result<Self> {
        let trimmed = raw.trim().trim_matches('"');
        if trimmed.is_empty() {
            bail!("proxy URL cannot be empty");
        }
        // A bare `host:port` (no scheme) is assumed to be an HTTP proxy.
        let normalized = if trimmed.contains("://") {
            trimmed.to_string()
        } else {
            format!("http://{trimmed}")
        };
        // Never echo the raw input: it may contain `user:password@`.
        let url = Url::parse(&normalized)
            .with_context(|| format!("invalid proxy URL: {}", redact_credentials(trimmed)))?;
        let scheme = match url.scheme().to_ascii_lowercase().as_str() {
            "http" | "https" => ProxyScheme::Http,
            "socks" | "socks5" => ProxyScheme::Socks5,
            other => bail!("unsupported proxy scheme: {other}"),
        };
        let host = url
            .host_str()
            .context("proxy URL is missing a host")?
            .to_string();
        let port = url
            .port_or_known_default()
            .or_else(|| default_port(scheme))
            .context("proxy URL is missing a port")?;
        let username = decode_non_empty(url.username());
        let password = url.password().and_then(decode_non_empty);
        Ok(Self {
            scheme,
            host,
            port,
            username,
            password,
        })
    }
}

/// Returns `raw` with everything between the `scheme://` prefix and the last `@` replaced
/// by `***`, so user names and passwords never reach error messages or logs.
///
/// The last `@` in the whole remainder is used, not only the one inside the authority,
/// because a malformed password may itself contain `/`, `?` or `#`. Over-masking a rare
/// `@` in a path only makes the message less detailed.
fn redact_credentials(raw: &str) -> String {
    let (prefix, rest) = match raw.find("://") {
        Some(index) => raw.split_at(index + "://".len()),
        None => ("", raw),
    };
    match rest.rfind('@') {
        Some(at) => format!("{prefix}***{}", &rest[at..]),
        None => raw.to_string(),
    }
}
impl fmt::Display for ProxyScheme {
    /// Writes the canonical lowercase scheme name (the same text as `as_str`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = self.as_str();
        f.write_str(name)
    }
}
/// Conventional port for each proxy scheme: 80 for HTTP, 1080 for SOCKS5.
///
/// Used by the parser only after `Url::port_or_known_default` has found nothing.
fn default_port(scheme: ProxyScheme) -> Option<u16> {
    let port = match scheme {
        ProxyScheme::Socks5 => 1080,
        ProxyScheme::Http => 80,
    };
    Some(port)
}
/// Percent-decodes a URL user-info component, mapping an empty component to `None`.
///
/// Invalid UTF-8 in the decoded bytes is replaced with U+FFFD rather than rejected.
fn decode_non_empty(value: &str) -> Option<String> {
    (!value.is_empty()).then(|| percent_decode_str(value).decode_utf8_lossy().into_owned())
}
/// Wraps a bare IPv6 literal in `[...]` so it can be followed by `:port`.
///
/// Hosts that are already bracketed, and hosts without any `:`, are returned unchanged.
fn format_host_for_url(host: &str) -> String {
    if host.starts_with('[') || !host.contains(':') {
        return host.to_owned();
    }
    format!("[{host}]")
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parses_http_proxy_without_scheme() {
        let proxy = UpstreamProxy::parse("127.0.0.1:7890").unwrap();
        assert_eq!(proxy.scheme(), ProxyScheme::Http);
        assert_eq!(proxy.host(), "127.0.0.1");
        assert_eq!(proxy.port(), 7890);
        assert!(!proxy.needs_bridge());
        assert_eq!(proxy.command_line_url(), "http://127.0.0.1:7890");
    }

    #[test]
    fn parses_socks_alias_with_auth() {
        let proxy = UpstreamProxy::parse("socks://user:p%40ss@localhost:1080").unwrap();
        assert_eq!(proxy.scheme(), ProxyScheme::Socks5);
        assert_eq!(proxy.username(), Some("user"));
        assert_eq!(proxy.password(), Some("p@ss"));
        assert!(proxy.needs_bridge());
        assert_eq!(proxy.command_line_url(), "socks5://localhost:1080");
    }

    #[test]
    fn authenticated_http_proxy_needs_bridge() {
        let proxy = UpstreamProxy::parse("http://user:pass@example.com:8080").unwrap();
        assert!(proxy.needs_bridge());
        assert!(proxy.basic_proxy_authorization().is_some());
        assert_eq!(proxy.command_line_url(), "http://example.com:8080");
    }

    #[test]
    fn strips_surrounding_whitespace_and_quotes() {
        let proxy = UpstreamProxy::parse("  \"http://127.0.0.1:7890\"  ").unwrap();
        assert_eq!(proxy.authority(), "127.0.0.1:7890");
    }

    #[test]
    fn applies_default_ports() {
        assert_eq!(UpstreamProxy::parse("proxy.local").unwrap().port(), 80);
        assert_eq!(UpstreamProxy::parse("socks5://proxy.local").unwrap().port(), 1080);
    }

    #[test]
    fn keeps_ipv6_hosts_bracketed() {
        let proxy = UpstreamProxy::parse("http://[::1]:3128").unwrap();
        assert_eq!(proxy.host(), "[::1]");
        assert_eq!(proxy.authority(), "[::1]:3128");
        assert_eq!(proxy.command_line_url(), "http://[::1]:3128");
    }

    #[test]
    fn encodes_basic_proxy_authorization() {
        let proxy = UpstreamProxy::parse("http://user:pass@example.com:8080").unwrap();
        assert_eq!(
            proxy.basic_proxy_authorization().as_deref(),
            Some("Basic dXNlcjpwYXNz")
        );
    }

    #[test]
    fn unauthenticated_proxy_has_no_authorization() {
        let proxy = UpstreamProxy::parse("http://example.com:8080").unwrap();
        assert!(!proxy.has_auth());
        assert_eq!(proxy.basic_proxy_authorization(), None);
    }

    #[test]
    fn rejects_empty_and_unsupported_input() {
        assert!(UpstreamProxy::parse("   ").is_err());
        assert!(UpstreamProxy::parse("\"\"").is_err());
        assert!(UpstreamProxy::parse("ftp://example.com:21").is_err());
    }
}