use http::{Response, StatusCode, Version};
use tokio::io::copy_bidirectional;
use tokio::net::TcpStream;
use tracing::{debug, error, warn};
use crate::config::BasicAuthUser;
use crate::{Body, ProxyError, empty_body, full_body, goals};
/// Forward-proxy handler that serves HTTP `CONNECT` requests by tunnelling
/// raw bytes between the client and the requested upstream.
///
/// When `auth_users` is non-empty, every request must carry valid
/// `Proxy-Authorization` credentials; when it is empty, no authentication
/// is performed.
pub struct ForwardProxy {
// Accounts allowed to use the proxy; an empty list disables authentication.
auth_users: Vec<ProxyAuthUser>,
}
/// One proxy account, pre-processed from the configuration so that the
/// hash format does not have to be re-detected on every request.
struct ProxyAuthUser {
// Login name, compared exactly against the name supplied by the client.
username: String,
// Stored secret. Treated as a bcrypt hash when `is_bcrypt` is true;
// otherwise it is compared byte-for-byte against the supplied password.
password_hash: String,
// True when `password_hash` starts with a bcrypt marker (`$2a$`, `$2b$`, `$2y$`).
is_bcrypt: bool,
}
impl ForwardProxy {
    /// Creates a forward proxy from the configured basic-auth users.
    ///
    /// Each user's stored secret is classified once, up front: a secret that
    /// begins with one of the bcrypt markers (`$2b$`, `$2a$`, `$2y$`) is flagged
    /// as a bcrypt hash, anything else is kept as a value to be compared
    /// directly. Passing an empty slice yields a proxy without authentication.
    pub fn new(auth_users: &[BasicAuthUser]) -> Self {
        // Version markers that identify a bcrypt hash string.
        const BCRYPT_PREFIXES: [&str; 3] = ["$2b$", "$2a$", "$2y$"];

        let mut prepared = Vec::with_capacity(auth_users.len());
        for configured in auth_users {
            let is_bcrypt = BCRYPT_PREFIXES
                .iter()
                .any(|&prefix| configured.password_hash.starts_with(prefix));
            prepared.push(ProxyAuthUser {
                username: configured.username.clone(),
                password_hash: configured.password_hash.clone(),
                is_bcrypt,
            });
        }

        Self {
            auth_users: prepared,
        }
    }
}
#[salvo::async_trait]
impl salvo::Handler for ForwardProxy {
    /// Salvo entry point: converts the incoming salvo request into the
    /// `http::Request<Body>` expected by [`ForwardProxy::run`], executes the
    /// proxy logic, and writes the outcome back into `res`.
    ///
    /// Both success and failure end the handler chain (`ctrl.skip_rest()`):
    /// errors are turned into responses via `into_response()` rather than
    /// being propagated to later handlers.
    async fn handle(
        &self,
        req: &mut salvo::Request,
        _depot: &mut salvo::Depot,
        res: &mut salvo::Response,
        ctrl: &mut salvo::FlowCtrl,
    ) {
        // Read the peer address before the request is handed to `strip_request`.
        let client_addr = crate::hoops::client_addr(req);

        let request = match goals::strip_request(req) {
            Ok(request) => request,
            Err(rejection) => {
                goals::merge_response(res, rejection.into_response());
                ctrl.skip_rest();
                return;
            }
        };

        let response = match self.run(request, client_addr).await {
            Ok(response) => response,
            Err(failure) => failure.into_response(),
        };

        goals::merge_response(res, response);
        ctrl.skip_rest();
    }
}
impl ForwardProxy {
async fn run(
&self,
mut request: http::Request<Body>,
_client_addr: std::net::SocketAddr,
) -> Result<Response<Body>, ProxyError> {
if request.method() != http::Method::CONNECT {
return Err(ProxyError::BadRequest(
"forward proxy only supports CONNECT".into(),
));
}
if !self.auth_users.is_empty() {
match extract_proxy_credentials(&request) {
Some((username, password)) => {
let ok = self
.auth_users
.iter()
.any(|u| verify_proxy_user(u, &username, &password));
if !ok {
debug!(
username = username.as_str(),
"proxy authentication failed, returning 407"
);
return Ok(proxy_auth_required_response());
}
}
None => {
debug!("no Proxy-Authorization header, returning 407");
return Ok(proxy_auth_required_response());
}
}
}
if request.version() == Version::HTTP_2 {
warn!("HTTP/2 CONNECT tunnel is not supported; client should use HTTP/1.1");
return Ok(Response::builder()
.status(StatusCode::NOT_IMPLEMENTED)
.body(crate::full_body(
"HTTP/2 CONNECT tunneling is not supported; use HTTP/1.1",
))?);
}
let authority = request
.uri()
.authority()
.map(|a| a.to_string())
.or_else(|| {
request.uri().host().map(|h| {
let port = request.uri().port_u16().unwrap_or(443);
format!("{h}:{port}")
})
})
.ok_or_else(|| ProxyError::BadRequest("CONNECT request missing authority".into()))?;
debug!(target = %authority, "CONNECT tunnel request");
let upstream = TcpStream::connect(&authority)
.await
.map_err(|e| ProxyError::Internal(format!("failed to connect to {authority}: {e}")))?;
upstream.set_nodelay(true).ok();
let client_upgrade = hyper::upgrade::on(&mut request);
let response = Response::builder()
.status(StatusCode::OK)
.body(empty_body())?;
tokio::spawn(async move {
match client_upgrade.await {
Ok(client_io) => {
let mut client_io = hyper_util::rt::TokioIo::new(client_io);
let mut upstream = upstream;
match copy_bidirectional(&mut client_io, &mut upstream).await {
Ok((up, down)) => {
debug!(
bytes_up = up,
bytes_down = down,
target = %authority,
"CONNECT tunnel closed"
);
}
Err(e) => {
debug!(error = %e, target = %authority, "CONNECT tunnel error");
}
}
}
Err(e) => {
error!(error = %e, "CONNECT upgrade failed");
}
}
});
Ok(response)
}
}
/// Extracts `(username, password)` from a `Proxy-Authorization: Basic …` header.
///
/// Returns `None` when the header is absent, is not visible ASCII, uses a
/// scheme other than Basic, does not decode as base64/UTF-8, or the decoded
/// text contains no `:` separator. The password is everything after the
/// first `:`, so it may itself contain colons.
///
/// The scheme name is matched case-insensitively, as HTTP authentication
/// schemes are case-insensitive tokens (`basic`, `Basic` and `BASIC` are
/// equivalent).
fn extract_proxy_credentials(req: &http::Request<Body>) -> Option<(String, String)> {
    let header_value = req.headers().get("proxy-authorization")?.to_str().ok()?;
    // Split "<scheme> <credentials>"; extra spaces before the credentials are
    // tolerated because `base64_decode` trims its input.
    let (scheme, encoded) = header_value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Basic") {
        return None;
    }
    let decoded = String::from_utf8(base64_decode(encoded)?).ok()?;
    let (username, password) = decoded.split_once(':')?;
    Some((username.to_string(), password.to_string()))
}
/// Leniently decodes standard-alphabet (`A-Z a-z 0-9 + /`) base64 text.
///
/// Behaviour:
/// * Surrounding whitespace is trimmed; empty input decodes to an empty vector.
/// * Decoding stops at the first `=`; anything after it is ignored.
/// * Embedded `\n`, `\r` and space characters are skipped.
/// * Any other character outside the alphabet yields `None`.
/// * Trailing bits that do not complete a byte are silently discarded.
fn base64_decode(input: &str) -> Option<Vec<u8>> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Some(Vec::new());
    }

    let mut decoded = Vec::with_capacity(trimmed.len() * 3 / 4);
    // `acc` holds bits that have been read but not yet emitted;
    // `pending` is how many of its low bits are meaningful.
    let mut acc: u32 = 0;
    let mut pending: u32 = 0;

    for byte in trimmed.bytes() {
        let sextet = match byte {
            b'=' => break,
            b'A'..=b'Z' => u32::from(byte - b'A'),
            b'a'..=b'z' => u32::from(byte - b'a') + 26,
            b'0'..=b'9' => u32::from(byte - b'0') + 52,
            b'+' => 62,
            b'/' => 63,
            b'\n' | b'\r' | b' ' => continue,
            _ => return None,
        };

        acc = (acc << 6) | sextet;
        pending += 6;
        if pending >= 8 {
            pending -= 8;
            // Emit the top complete byte and keep only the leftover low bits.
            decoded.push((acc >> pending) as u8);
            acc &= (1 << pending) - 1;
        }
    }

    Some(decoded)
}
/// Checks whether `username`/`password` match the stored account `user`.
///
/// * A different username never matches.
/// * Non-bcrypt accounts compare the supplied password byte-for-byte against
///   the stored value using `constant_time_eq`.
/// * Bcrypt accounts are verified with `bcrypt::verify` when the `bcrypt`
///   feature is compiled in (verification errors count as a mismatch);
///   without that feature they are always rejected, with a warning.
fn verify_proxy_user(user: &ProxyAuthUser, username: &str, password: &str) -> bool {
    if user.username != username {
        return false;
    }

    if !user.is_bcrypt {
        return constant_time_eq(password.as_bytes(), user.password_hash.as_bytes());
    }

    // Exactly one of the two blocks below survives conditional compilation
    // and becomes the function's tail expression.
    #[cfg(feature = "bcrypt")]
    {
        bcrypt::verify(password, &user.password_hash).unwrap_or(false)
    }
    #[cfg(not(feature = "bcrypt"))]
    {
        warn!("bcrypt password hash found but bcrypt feature is not enabled, rejecting");
        false
    }
}
/// Compares two byte slices without exiting early on the first differing byte.
///
/// Slices of different lengths are unequal and are rejected immediately, so
/// only the length (not the content) can influence how long the call takes.
/// For equal lengths, every byte pair is XOR-ed and the differences are
/// OR-accumulated; the slices are equal iff the accumulator stays zero.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .fold(0u8, |mismatch, (lhs, rhs)| mismatch | (lhs ^ rhs))
            == 0
}
/// Builds the `407 Proxy Authentication Required` challenge sent when
/// credentials are missing or invalid.
///
/// The response advertises the Basic scheme with realm `gatel` through the
/// `Proxy-Authenticate` header and carries a short plain-text body.
fn proxy_auth_required_response() -> Response<Body> {
    let challenge = Response::builder()
        .status(StatusCode::PROXY_AUTHENTICATION_REQUIRED)
        .header("Proxy-Authenticate", "Basic realm=\"gatel\"");
    let body = full_body("Proxy Authentication Required");
    // The status and header are fixed, well-formed literals, so the builder
    // is not expected to fail here.
    challenge.body(body).unwrap()
}