use crate::error::SqlError;
use base64::Engine;
use secrecy::{ExposeSecret, SecretString};
use std::env;
/// Connection details for a proxy, as produced by [`ProxyConfig::parse`].
///
/// NOTE(review): `Debug` is derived and `url` holds the string exactly as it
/// was supplied, so debug-printing this struct will print any credentials
/// embedded in that URL even though `password` itself is a `SecretString`.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
/// The proxy URL exactly as it was passed to `parse` (not normalized).
pub url: String,
/// Host component of the proxy URL.
pub host: String,
/// Port of the proxy; `parse` falls back to 8080 when it finds none.
pub port: u16,
/// User name for proxy authentication, when the URL carried credentials.
pub username: Option<String>,
/// Password for proxy authentication, when the URL carried credentials.
pub password: Option<SecretString>,
}
impl ProxyConfig {
    /// Port assumed when the proxy URL does not name one.
    const DEFAULT_PORT: u16 = 8080;

    /// Parses a proxy URL such as `http://user:pass@proxy.example:3128`.
    ///
    /// * `host` is taken from the URL's host component.
    /// * `port` is the port written in the URL, or 8080 when none is written.
    /// * `username` / `password` are populated only when the URL contains a
    ///   password; both are percent-decoded so that reserved characters
    ///   (for example `%40` for `@`) reach the proxy in their literal form.
    /// * `url` keeps the input string verbatim.
    ///
    /// # Errors
    ///
    /// Returns [`SqlError::InvalidUrl`] when the string is not a valid URL or
    /// the URL has no host.
    pub fn parse(url: &str) -> Result<Self, SqlError> {
        let parsed = ::url::Url::parse(url)
            .map_err(|e| SqlError::InvalidUrl(format!("proxy URL: {e}")))?;

        let host = parsed
            .host_str()
            .ok_or_else(|| SqlError::InvalidUrl("proxy URL has no host".to_string()))?
            .to_string();

        // `Url::port()` reports `None` whenever the port equals the scheme's
        // default (e.g. `http://proxy:80`), so an explicitly written default
        // port has to be recovered from the raw text before falling back.
        let port = parsed
            .port()
            .or_else(|| Self::explicit_default_port(url, &parsed))
            .unwrap_or(Self::DEFAULT_PORT);

        let (username, password) = match parsed.password() {
            Some(secret) => (
                Some(Self::percent_decode(parsed.username())),
                Some(SecretString::new(Self::percent_decode(secret).into())),
            ),
            None => (None, None),
        };

        Ok(ProxyConfig {
            url: url.to_string(),
            host,
            port,
            username,
            password,
        })
    }

    /// Returns the scheme's default port when `raw` spells it out explicitly
    /// (`http://proxy:80`), and `None` when the URL carries no port at all.
    ///
    /// Only ever yields the scheme default: any other explicit port is
    /// already reported by `Url::port()`.
    fn explicit_default_port(raw: &str, parsed: &::url::Url) -> Option<u16> {
        let scheme_default = parsed.port_or_known_default()?;
        let after_scheme = raw.split_once("://").map_or(raw, |(_, rest)| rest);
        let authority = after_scheme
            .split(|c: char| matches!(c, '/' | '?' | '#'))
            .next()
            .unwrap_or(after_scheme);
        // Drop `user:pass@` so a colon inside the userinfo is not mistaken
        // for the port separator.
        let host_port = authority.rsplit_once('@').map_or(authority, |(_, hp)| hp);
        // For a bracketed IPv6 host without a port the text after the last
        // colon ends in `]` and fails to parse, which correctly yields `None`.
        let (_, port_text) = host_port.rsplit_once(':')?;
        port_text
            .trim()
            .parse::<u16>()
            .ok()
            .filter(|port| *port == scheme_default)
    }

    /// Decodes `%XX` escapes in `input`; malformed escapes are kept verbatim
    /// and byte sequences that are not valid UTF-8 are replaced lossily.
    fn percent_decode(input: &str) -> String {
        let bytes = input.as_bytes();
        let hex_at = |idx: usize| bytes.get(idx).and_then(|b| (*b as char).to_digit(16));

        let mut decoded = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'%' {
                if let (Some(hi), Some(lo)) = (hex_at(i + 1), hex_at(i + 2)) {
                    // Two hex digits are at most 0xFF, so the cast is lossless.
                    decoded.push((hi * 16 + lo) as u8);
                    i += 3;
                    continue;
                }
            }
            decoded.push(bytes[i]);
            i += 1;
        }
        String::from_utf8_lossy(&decoded).into_owned()
    }
}
/// Reports whether `target_host` is exempted from proxying by the `NO_PROXY`
/// environment variable (falling back to lowercase `no_proxy` when the
/// uppercase form is unset or empty).
///
/// The variable is a comma-separated list. Matching is ASCII
/// case-insensitive and works per entry as follows:
///
/// * `*` matches every host.
/// * `.example.com` matches any host ending in `.example.com` (but not
///   `example.com` itself).
/// * `example.com` matches that host and any subdomain of it.
/// * A trailing `:port` on an entry is ignored.
/// * IPv6 literals may be written bare (`::1`) or bracketed (`[::1]`,
///   `[::1]:5432`); brackets around `target_host` are ignored as well.
///
/// Returns `false` when neither variable is set to a non-empty value.
pub fn is_no_proxy(target_host: &str) -> bool {
    let Some(no_proxy) = ["NO_PROXY", "no_proxy"]
        .iter()
        .find_map(|name| env::var(name).ok().filter(|value| !value.is_empty()))
    else {
        return false;
    };

    let lowered = target_host.to_ascii_lowercase();
    // Accept `[::1]` as well as `::1` for the target.
    let target = lowered
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(lowered.as_str());

    no_proxy
        .split(',')
        .map(|entry| entry.trim().to_ascii_lowercase())
        .filter(|entry| !entry.is_empty())
        .any(|entry| {
            if entry == "*" {
                return true;
            }
            let host = no_proxy_pattern_host(&entry);
            if host.is_empty() {
                // An entry such as `:8080` names no host and must not match.
                return false;
            }
            if host.starts_with('.') {
                target.ends_with(host)
            } else {
                // Exact match, or a subdomain: the text before the suffix
                // must end at a label boundary.
                target
                    .strip_suffix(host)
                    .is_some_and(|prefix| prefix.is_empty() || prefix.ends_with('.'))
            }
        })
}

/// Extracts the host part of one `NO_PROXY` entry, dropping an optional
/// `:port` suffix without mangling IPv6 literals.
fn no_proxy_pattern_host(pattern: &str) -> &str {
    if let Some(rest) = pattern.strip_prefix('[') {
        // Bracketed IPv6 literal, optionally followed by `:port`.
        return rest.split(']').next().unwrap_or(rest);
    }
    match pattern.split_once(':') {
        // Exactly one colon: `host:port`.
        Some((host, port)) if !port.contains(':') => host,
        // No colon (plain host) or several (bare IPv6 literal): keep whole.
        _ => pattern,
    }
}
/// Looks up a proxy configuration in the process environment.
///
/// Variables are consulted in this order and the first one that is set,
/// non-empty and parses as a proxy URL wins: `ALL_PROXY`, `all_proxy`,
/// `HTTPS_PROXY`, `https_proxy`, `HTTP_PROXY`, `http_proxy`. The lowercase
/// spellings are checked because many Unix tools set only those.
///
/// A variable whose value fails to parse is skipped rather than reported,
/// so lookup continues with the next candidate.
///
/// `_target_scheme` is currently unused; the same variables apply to every
/// target.
///
/// Returns `None` when no candidate yields a valid configuration.
pub fn resolve_proxy_from_env(_target_scheme: &str) -> Option<ProxyConfig> {
    const CANDIDATES: [&str; 6] = [
        "ALL_PROXY",
        "all_proxy",
        "HTTPS_PROXY",
        "https_proxy",
        "HTTP_PROXY",
        "http_proxy",
    ];

    CANDIDATES.iter().find_map(|name| {
        env::var(name)
            .ok()
            .filter(|value| !value.is_empty())
            .and_then(|value| ProxyConfig::parse(&value).ok())
    })
}
/// Opens a TCP connection to `proxy` and asks it, via HTTP `CONNECT`, to
/// tunnel to `target_host:target_port`.
///
/// When the proxy configuration carries both a user name and a password a
/// `Proxy-Authorization: Basic …` header is sent.
///
/// On success the returned stream is positioned exactly after the proxy's
/// response head, so the first byte read from it is the first byte sent by
/// the target.
///
/// # Errors
///
/// Returns [`SqlError::ConnectionFailed`] when the proxy cannot be reached,
/// the request cannot be written, the proxy closes the connection or sends
/// an oversized/non-UTF-8 response, or the status is not `200`.
pub(crate) async fn http_connect(
    proxy: &ProxyConfig,
    target_host: &str,
    target_port: u16,
) -> Result<tokio::net::TcpStream, SqlError> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Upper bound on the response head; guards against a peer that never
    // terminates its headers.
    const MAX_RESPONSE_HEAD: usize = 8 * 1024;

    // A URL host for an IPv6 literal is bracketed (`[::1]`), which neither
    // parses as an IP address nor resolves via DNS; connect to the bare form.
    let connect_host = proxy
        .host
        .strip_prefix('[')
        .and_then(|rest| rest.strip_suffix(']'))
        .unwrap_or(proxy.host.as_str());
    let mut stream = tokio::net::TcpStream::connect((connect_host, proxy.port))
        .await
        .map_err(|e| {
            SqlError::ConnectionFailed(format!(
                "proxy connect to {}:{}: {e}",
                proxy.host, proxy.port
            ))
        })?;

    // An IPv6 literal must be bracketed inside an authority, otherwise its
    // colons are indistinguishable from the port separator.
    let authority = if target_host.contains(':') && !target_host.starts_with('[') {
        format!("[{target_host}]:{target_port}")
    } else {
        format!("{target_host}:{target_port}")
    };

    let mut request = format!("CONNECT {authority} HTTP/1.1\r\nHost: {authority}\r\n");
    if let (Some(user), Some(pass)) = (&proxy.username, &proxy.password) {
        let creds = format!("{}:{}", user, pass.expose_secret());
        let encoded = base64::prelude::BASE64_STANDARD.encode(creds);
        request.push_str(&format!("Proxy-Authorization: Basic {encoded}\r\n"));
    }
    request.push_str("\r\n");

    stream
        .write_all(request.as_bytes())
        .await
        .map_err(|e| SqlError::ConnectionFailed(format!("proxy write: {e}")))?;

    // Read the response head one byte at a time. A larger read could consume
    // bytes that follow the blank line and belong to the tunnelled protocol
    // (servers that speak first), and this function can only hand back the
    // raw stream, not a buffer. The head is small and read once per
    // connection, so the extra syscalls are negligible.
    let mut head: Vec<u8> = Vec::with_capacity(256);
    let mut byte = [0u8; 1];
    let mut status_checked = false;
    loop {
        if head.len() >= MAX_RESPONSE_HEAD {
            return Err(SqlError::ConnectionFailed(format!(
                "proxy response headers exceed {MAX_RESPONSE_HEAD} bytes"
            )));
        }
        let n = stream
            .read(&mut byte)
            .await
            .map_err(|e| SqlError::ConnectionFailed(format!("proxy read: {e}")))?;
        if n == 0 {
            return Err(SqlError::ConnectionFailed(
                "proxy closed connection before completing CONNECT response".to_string(),
            ));
        }
        head.push(byte[0]);
        if byte[0] != b'\n' {
            continue;
        }

        if !status_checked {
            // The first newline completes the status line; reject failures
            // immediately instead of waiting for the rest of the headers.
            let status_line = std::str::from_utf8(&head)
                .map_err(|_| {
                    SqlError::ConnectionFailed("proxy returned non-UTF-8 response".to_string())
                })?
                .trim();
            if !status_line.starts_with("HTTP/1.1 200") && !status_line.starts_with("HTTP/1.0 200")
            {
                return Err(SqlError::ConnectionFailed(format!(
                    "proxy error: {status_line}"
                )));
            }
            status_checked = true;
        }

        // A blank line ends the head (tolerating bare-LF line endings).
        if head.ends_with(b"\r\n\r\n") || head.ends_with(b"\n\n") {
            break;
        }
    }

    Ok(stream)
}
/// A connection wrapper whose `AsyncConnection` implementation forwards every
/// call to `inner`.
pub struct ProxiedConnection {
/// The wrapped connection that actually services each request.
pub(crate) inner: Box<dyn crate::connection::AsyncConnection>,
/// Optional handle to a spawned tokio task. Nothing in this file reads it;
/// presumably it is the task relaying traffic through the proxy tunnel and is
/// stored here to tie it to this connection — TODO confirm against the code
/// that constructs this struct.
/// NOTE(review): dropping a `JoinHandle` detaches the task rather than
/// cancelling it — confirm the task ends on its own when the connection goes
/// away.
pub(crate) forwarder: Option<tokio::task::JoinHandle<()>>,
}
// `AsyncConnection` for `ProxiedConnection` is pure delegation: each method
// passes its arguments unchanged to the same method on `self.inner` and
// returns that result as-is. The `forwarder` field is not touched here.
#[async_trait::async_trait]
impl crate::connection::AsyncConnection for ProxiedConnection {
// Delegates to `inner.execute`.
async fn execute(&mut self, sql: &str) -> Result<crate::ExecutionSummary, crate::SqlError> {
self.inner.execute(sql).await
}
// Delegates to `inner.query`.
async fn query(&mut self, sql: &str) -> Result<crate::QueryResult, crate::SqlError> {
self.inner.query(sql).await
}
// Delegates to `inner.query_stream`; the returned row stream borrows from
// `self` (the `'_` lifetime), so it must be dropped before the connection is
// used again.
async fn query_stream(
&mut self,
sql: &str,
) -> Result<(Vec<crate::ColumnInfo>, crate::BoxRowStream<'_>), crate::SqlError> {
self.inner.query_stream(sql).await
}
// Delegates to `inner.execute_multi`.
async fn execute_multi(
&mut self,
sql: &str,
) -> Result<Vec<crate::StatementResult>, crate::SqlError> {
self.inner.execute_multi(sql).await
}
// Delegates to `inner.ping`.
async fn ping(&mut self) -> Result<(), crate::SqlError> {
self.inner.ping().await
}
// Delegates to `inner.list_tables`.
async fn list_tables(&mut self, schema: Option<&str>) -> Result<Vec<String>, crate::SqlError> {
self.inner.list_tables(schema).await
}
// Delegates to `inner.list_schemas`.
async fn list_schemas(
&mut self,
) -> Result<Vec<crate::connection::SchemaInfo>, crate::SqlError> {
self.inner.list_schemas().await
}
// Delegates to `inner.describe_table`.
async fn describe_table(
&mut self,
schema: Option<&str>,
table: &str,
) -> Result<crate::QueryResult, crate::SqlError> {
self.inner.describe_table(schema, table).await
}
// Delegates to `inner.primary_key`.
async fn primary_key(
&mut self,
schema: Option<&str>,
table: &str,
) -> Result<Vec<String>, crate::SqlError> {
self.inner.primary_key(schema, table).await
}
// Delegates to `inner.list_foreign_keys`.
async fn list_foreign_keys(
&mut self,
schema: Option<&str>,
) -> Result<Vec<crate::ForeignKey>, crate::SqlError> {
self.inner.list_foreign_keys(schema).await
}
// Delegates to `inner.bulk_insert_rows`.
async fn bulk_insert_rows(
&mut self,
target: crate::connection::BulkInsert<'_>,
) -> Result<usize, crate::SqlError> {
self.inner.bulk_insert_rows(target).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::Mutex;

    /// Serializes the tests in this module that read or mutate the process
    /// environment, which is global state shared by all test threads.
    static ENV_GUARD: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_GUARD`], recovering the lock if an earlier test
    /// panicked while holding it.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        ENV_GUARD.lock().unwrap_or_else(|p| p.into_inner())
    }

    /// Puts the saved environment values back when dropped, so a failing
    /// assertion cannot leak overrides into later tests.
    struct EnvRestore(Vec<(&'static str, Option<OsString>)>);

    impl Drop for EnvRestore {
        fn drop(&mut self) {
            for (name, previous) in self.0.drain(..) {
                // SAFETY: callers hold `ENV_GUARD` for the lifetime of this
                // value, which serializes the environment access done by the
                // tests in this module. Assumes no other thread in the test
                // binary touches the environment concurrently.
                unsafe {
                    match previous {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    /// Applies `vars` (`Some` sets, `None` unsets) and returns a guard that
    /// restores the previous values. Call only while holding [`env_lock`],
    /// and declare the guard *after* the lock so it is dropped first.
    #[must_use]
    fn override_env(vars: &[(&'static str, Option<&str>)]) -> EnvRestore {
        let saved = vars
            .iter()
            .map(|(name, _)| (*name, std::env::var_os(name)))
            .collect();
        for (name, value) in vars {
            // SAFETY: see `EnvRestore::drop`; the caller holds `ENV_GUARD`.
            unsafe {
                match value {
                    Some(v) => std::env::set_var(name, v),
                    None => std::env::remove_var(name),
                }
            }
        }
        EnvRestore(saved)
    }

    #[test]
    fn test_parse_simple() {
        let cfg = ProxyConfig::parse("http://proxy:8080").unwrap();
        assert_eq!(cfg.host, "proxy");
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.username, None);
        assert!(cfg.password.is_none());
    }

    #[test]
    fn test_parse_with_auth() {
        let cfg = ProxyConfig::parse("http://user:pass@proxy:3128").unwrap();
        assert_eq!(cfg.host, "proxy");
        assert_eq!(cfg.port, 3128);
        assert_eq!(cfg.username, Some("user".to_string()));
        assert_eq!(cfg.password.as_ref().unwrap().expose_secret(), "pass");
    }

    #[test]
    fn test_parse_no_port_uses_8080() {
        let cfg = ProxyConfig::parse("http://proxy").unwrap();
        assert_eq!(cfg.port, 8080);
    }

    #[test]
    fn test_is_no_proxy_star() {
        let _guard = env_lock();
        let _env = override_env(&[("NO_PROXY", Some("*"))]);
        assert!(is_no_proxy("anything"));
    }

    #[test]
    fn test_is_no_proxy_exact() {
        let _guard = env_lock();
        let _env = override_env(&[("NO_PROXY", Some("localhost"))]);
        assert!(is_no_proxy("localhost"));
        assert!(!is_no_proxy("otherhost"));
    }

    #[test]
    fn test_is_no_proxy_suffix() {
        let _guard = env_lock();
        let _env = override_env(&[("NO_PROXY", Some(".example.com"))]);
        assert!(is_no_proxy("db.example.com"));
        assert!(!is_no_proxy("example.com"));
    }

    #[test]
    fn test_resolve_proxy_from_env_empty() {
        let _guard = env_lock();
        // Clear every proxy variable (both spellings) so the result does not
        // depend on the machine running the tests.
        let _env = override_env(&[
            ("ALL_PROXY", None),
            ("all_proxy", None),
            ("HTTPS_PROXY", None),
            ("https_proxy", None),
            ("HTTP_PROXY", None),
            ("http_proxy", None),
        ]);
        assert!(resolve_proxy_from_env("postgres").is_none());
    }
}