use crate::audit;
use crate::error::{ProxyError, Result};
use std::net::SocketAddr;
use std::time::Duration;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tracing::debug;
/// Upper bound on a single upstream TCP connect attempt (30 seconds).
///
/// In this module it bounds each per-address attempt in [`connect_to_resolved`] and the
/// unresolved `host:port` fallback in [`open_tcp_upstream`]; the external-proxy path does
/// not wrap its connect in this timeout.
const UPSTREAM_CONNECT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Transport used to talk to the upstream origin.
///
/// [`forward_request`] wraps the upstream TCP connection in TLS (using
/// [`UpstreamSpec::tls_connector`]) for `Https`, and uses the TCP connection as-is for
/// `Http`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpstreamScheme {
    /// Plaintext HTTP over the raw TCP connection.
    Http,
    /// HTTP over a TLS session negotiated with `UpstreamSpec::host`.
    Https,
}
/// How the upstream TCP connection is established.
pub enum UpstreamStrategy<'a> {
    /// Connect straight to the origin.
    ///
    /// Non-empty `resolved_addrs` are tried in order, each attempt bounded by
    /// `UPSTREAM_CONNECT_TIMEOUT`; the first successful connect is used. When the slice is
    /// empty, the `"{host}:{port}"` string from the enclosing [`UpstreamSpec`] is handed to
    /// `TcpStream::connect` instead.
    Direct { resolved_addrs: &'a [SocketAddr] },
    /// Reach the origin through another proxy via `crate::external::connect_via_proxy`.
    ///
    /// NOTE(review): presumably an HTTP CONNECT tunnel — confirm in `crate::external`.
    ExternalProxy {
        /// Address of the external proxy, passed unchanged to `connect_via_proxy`.
        proxy_addr: &'a str,
        /// Optional header value passed unchanged to `connect_via_proxy`.
        /// NOTE(review): presumably a Proxy-Authorization credential — if so, do not add a
        /// `Debug` derive to this enum without redacting it.
        proxy_auth_header: Option<&'a str>,
    },
}
/// Everything needed to open a connection to the upstream origin for one request.
pub struct UpstreamSpec<'a> {
    /// Whether to speak plain HTTP or HTTPS to the origin.
    pub scheme: UpstreamScheme,
    /// Origin host. Used as the TLS server name for HTTPS, as the target of the
    /// unresolved direct connect and of the external proxy, and in connect error messages.
    pub host: &'a str,
    /// Origin port. Used by the unresolved direct connect and the external proxy path;
    /// pre-resolved `SocketAddr`s carry their own port.
    pub port: u16,
    /// How the TCP connection is established.
    pub strategy: UpstreamStrategy<'a>,
    /// TLS client used for the handshake; only consulted when `scheme` is `Https`.
    pub tls_connector: &'a TlsConnector,
}
/// Audit metadata for one forwarded request.
///
/// After the response has been relayed successfully, [`forward_request`] passes these
/// fields, together with the parsed response status, to `audit::log_l7_request`.
pub struct AuditCtx<'a> {
    /// Shared audit log handle, if any.
    /// NOTE(review): presumably `None` disables logging inside `log_l7_request` — confirm.
    pub log: Option<&'a audit::SharedAuditLog>,
    /// Proxy mode recorded with the event.
    pub mode: audit::ProxyMode,
    /// Event context passed by reference to the audit logger.
    pub event_ctx: audit::EventContext<'a>,
    /// Request target recorded with the event (format chosen by the caller).
    pub target: &'a str,
    /// Request method recorded with the event.
    pub method: &'a str,
    /// Request path recorded with the event.
    pub path: &'a str,
}
pub async fn forward_request<S>(
inbound: &mut S,
request_bytes: &[u8],
body: &[u8],
upstream: UpstreamSpec<'_>,
audit: AuditCtx<'_>,
) -> Result<u16>
where
S: AsyncRead + AsyncWrite + Unpin,
{
let status = match upstream.scheme {
UpstreamScheme::Https => {
let mut tls_stream = open_https_upstream(&upstream).await?;
write_request(&mut tls_stream, request_bytes, body).await?;
stream_response(&mut tls_stream, inbound).await?
}
UpstreamScheme::Http => {
let mut tcp_stream = open_http_upstream(&upstream).await?;
write_request(&mut tcp_stream, request_bytes, body).await?;
stream_response(&mut tcp_stream, inbound).await?
}
};
audit::log_l7_request(
audit.log,
audit.mode,
&audit.event_ctx,
audit.target,
audit.method,
audit.path,
status,
);
Ok(status)
}
/// Opens a TLS session to the upstream host over the connection from
/// [`open_tcp_upstream`].
///
/// `upstream.host` is converted into the rustls `ServerName` used for the handshake. The
/// conversion is done *before* dialing. A host that can never complete a TLS handshake
/// therefore fails immediately, instead of first costing a TCP connect or an
/// external-proxy tunnel that would be thrown away.
///
/// # Errors
/// - `ProxyError::UpstreamConnect` with reason `"invalid server name for TLS"` when `host`
///   cannot be converted into a `ServerName`.
/// - Any error from [`open_tcp_upstream`].
/// - `ProxyError::UpstreamConnect` with a `"TLS handshake failed: …"` reason when the
///   handshake fails.
///
/// NOTE(review): if `host` can arrive as a bracketed IPv6 literal (`[::1]`), the
/// `ServerName` conversion rejects it. Confirm how callers format the host.
async fn open_https_upstream(
    upstream: &UpstreamSpec<'_>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>> {
    let server_name = rustls::pki_types::ServerName::try_from(upstream.host.to_string())
        .map_err(|_| ProxyError::UpstreamConnect {
            host: upstream.host.to_string(),
            reason: "invalid server name for TLS".to_string(),
        })?;

    let tcp = open_tcp_upstream(upstream).await?;

    upstream
        .tls_connector
        .connect(server_name, tcp)
        .await
        .map_err(|e| ProxyError::UpstreamConnect {
            host: upstream.host.to_string(),
            reason: format!("TLS handshake failed: {}", e),
        })
}
/// Opens the plaintext connection used for `http://` upstreams.
///
/// Plain HTTP needs no transport wrapping. The result is therefore exactly the TCP stream
/// produced by [`open_tcp_upstream`] for whichever [`UpstreamStrategy`] the spec selects.
async fn open_http_upstream(upstream: &UpstreamSpec<'_>) -> Result<TcpStream> {
    let stream = open_tcp_upstream(upstream).await?;
    Ok(stream)
}
/// Opens the raw TCP connection to the upstream according to `upstream.strategy`.
///
/// - `Direct` with pre-resolved addresses delegates to [`connect_to_resolved`].
/// - `Direct` with an empty address list hands the `"{host}:{port}"` string to
///   `TcpStream::connect`, bounded by [`UPSTREAM_CONNECT_TIMEOUT`].
/// - `ExternalProxy` delegates to `crate::external::connect_via_proxy`. Its
///   `ProxyError::ExternalProxy` failures are rewritten as `ProxyError::UpstreamConnect`
///   for this host; every other error passes through unchanged.
///
/// NOTE(review): no timeout wraps `connect_via_proxy` here. Confirm it bounds its own
/// connect.
///
/// NOTE(review): the empty-address fallback lets `TcpStream::connect` resolve `host`
/// itself. If pre-resolution exists to enforce a destination policy, verify that callers
/// never reach this path for hosts that must be checked.
async fn open_tcp_upstream(upstream: &UpstreamSpec<'_>) -> Result<TcpStream> {
    // Every failure reported from this function is attributed to the origin host.
    let connect_err = |reason: String| ProxyError::UpstreamConnect {
        host: upstream.host.to_string(),
        reason,
    };

    match upstream.strategy {
        UpstreamStrategy::Direct { resolved_addrs } if !resolved_addrs.is_empty() => {
            connect_to_resolved(resolved_addrs, upstream.host).await
        }
        UpstreamStrategy::Direct { .. } => {
            let addr = format!("{}:{}", upstream.host, upstream.port);
            tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(&addr))
                .await
                .map_err(|_elapsed| connect_err("connection timed out".to_string()))?
                .map_err(|io_err| connect_err(io_err.to_string()))
        }
        UpstreamStrategy::ExternalProxy {
            proxy_addr,
            proxy_auth_header,
        } => crate::external::connect_via_proxy(
            proxy_addr,
            upstream.host,
            upstream.port,
            proxy_auth_header,
        )
        .await
        .map_err(|err| match err {
            ProxyError::ExternalProxy(reason) => connect_err(reason),
            other => other,
        }),
    }
}
/// Connects to the first reachable address in `addrs`, trying them one after another.
///
/// Each attempt is bounded by [`UPSTREAM_CONNECT_TIMEOUT`]. Because attempts are
/// sequential, the worst case is `addrs.len()` times that timeout. Individual failures are
/// logged at debug level.
///
/// # Errors
/// Returns `ProxyError::UpstreamConnect` for `host` when every attempt fails. The reason
/// is taken from the last attempt, or is `"no addresses to connect to"` when `addrs` is
/// empty.
async fn connect_to_resolved(addrs: &[SocketAddr], host: &str) -> Result<TcpStream> {
    let mut failure: Option<String> = None;

    for &target in addrs {
        let attempt =
            tokio::time::timeout(UPSTREAM_CONNECT_TIMEOUT, TcpStream::connect(target)).await;
        let reason = match attempt {
            Ok(Ok(stream)) => return Ok(stream),
            Ok(Err(err)) => {
                debug!("Connect to {} failed: {}", target, err);
                err.to_string()
            }
            Err(_elapsed) => {
                debug!("Connect to {} timed out", target);
                "connection timed out".to_string()
            }
        };
        failure = Some(reason);
    }

    let reason = failure.unwrap_or_else(|| "no addresses to connect to".to_string());
    Err(ProxyError::UpstreamConnect {
        host: host.to_string(),
        reason,
    })
}
/// Sends the serialized request head, then the body if there is one, and flushes.
///
/// # Errors
/// Propagates any I/O error from writing or flushing `stream`, converted into the crate
/// error type by `?`.
async fn write_request<S>(stream: &mut S, request: &[u8], body: &[u8]) -> Result<()>
where
    S: AsyncWrite + Unpin,
{
    stream.write_all(request).await?;
    // Requests without a body (e.g. a bodiless GET) skip the second write entirely.
    match body {
        [] => {}
        payload => stream.write_all(payload).await?,
    }
    stream.flush().await?;
    Ok(())
}
/// Copies the upstream response to `inbound` chunk by chunk and reports its status code.
///
/// Each chunk read from `upstream` is written and flushed to `inbound` immediately, so the
/// client sees the response as it streams. Relaying stops at upstream EOF or at the first
/// upstream read error. Such an error is logged at debug level and otherwise treated as
/// end of response. If it happens before any byte arrives, nothing is written to
/// `inbound` and 502 is returned.
///
/// The status comes from the response's first line. A single `read` may return only part
/// of that line, so up to `STATUS_PREFIX_LEN` leading bytes are accumulated across reads
/// until a CR/LF appears or the cap is reached, and only then parsed. Looking at the first
/// chunk alone would report 502 whenever the first read happened to be short.
///
/// NOTE(review): relaying ends only when the upstream closes the connection. This assumes
/// the forwarded request asked for `Connection: close`, or that the upstream closes on its
/// own. Confirm with the caller.
///
/// # Returns
/// The parsed status code, or 502 when no valid status line was seen.
///
/// # Errors
/// Returns an error only if writing or flushing `inbound` fails.
async fn stream_response<U, I>(upstream: &mut U, inbound: &mut I) -> Result<u16>
where
    U: AsyncRead + Unpin,
    I: AsyncWrite + Unpin,
{
    // Matches the 64-byte cap `parse_response_status` applies to the first line, so
    // buffering more than this could never change the parsed result.
    const STATUS_PREFIX_LEN: usize = 64;

    let mut buf = [0u8; 8192];
    let mut head: Vec<u8> = Vec::with_capacity(STATUS_PREFIX_LEN);
    let mut status_code: Option<u16> = None;

    loop {
        let n = match upstream.read(&mut buf).await {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) => {
                debug!("Upstream read error: {}", e);
                break;
            }
        };
        let chunk = &buf[..n];

        if status_code.is_none() {
            let room = STATUS_PREFIX_LEN - head.len();
            head.extend_from_slice(&chunk[..chunk.len().min(room)]);
            let line_complete = head.iter().any(|&b| b == b'\r' || b == b'\n');
            if line_complete || head.len() == STATUS_PREFIX_LEN {
                status_code = Some(parse_response_status(&head));
            }
        }

        inbound.write_all(chunk).await?;
        inbound.flush().await?;
    }

    // The upstream may close before sending a line terminator; parse whatever arrived
    // (an empty prefix yields 502).
    Ok(status_code.unwrap_or_else(|| parse_response_status(&head)))
}
/// Extracts the status code from the beginning of an HTTP/1.x response.
///
/// Only the first line is examined: everything up to the first CR or LF, capped at 64
/// bytes. The line must start with a token beginning `HTTP/`, followed by a status-code
/// token of exactly three ASCII digits (RFC 9112 `status-code = 3DIGIT`).
///
/// The line is inspected as raw bytes rather than decoded as UTF-8. A reason phrase
/// containing `obs-text` (bytes 0x80–0xFF, which RFC 9112 permits) therefore does not hide
/// the status code. Signs such as `+20` are rejected even though `str::parse::<u16>` would
/// accept them.
///
/// Returns 502 (Bad Gateway) when no valid status code is found.
fn parse_response_status(data: &[u8]) -> u16 {
    // Status reported when the first line is not a usable HTTP status line.
    const UNPARSEABLE: u16 = 502;
    // Upper bound on how much of the first line is scanned.
    const MAX_SCAN: usize = 64;

    let line_end = data
        .iter()
        .position(|&b| b == b'\r' || b == b'\n')
        .unwrap_or(data.len());
    let first_line = &data[..line_end.min(MAX_SCAN)];

    // Split on ASCII whitespace and drop empty fields, like `str::split_whitespace`.
    let mut fields = first_line
        .split(u8::is_ascii_whitespace)
        .filter(|field| !field.is_empty());

    if !matches!(fields.next(), Some(version) if version.starts_with(b"HTTP/")) {
        return UNPARSEABLE;
    }

    match fields.next() {
        Some(code) if code.len() == 3 && code.iter().all(u8::is_ascii_digit) => code
            .iter()
            .fold(0, |acc, &digit| acc * 10 + u16::from(digit - b'0')),
        _ => UNPARSEABLE,
    }
}
#[cfg(test)]
// Test code may unwrap freely; presumably `clippy::unwrap_used` is denied crate-wide.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Asserts that each `(response_head, expected_status)` pair parses as expected,
    /// naming the offending input on failure.
    fn assert_statuses(cases: &[(&str, u16)]) {
        for &(head, expected) in cases {
            assert_eq!(
                parse_response_status(head.as_bytes()),
                expected,
                "input: {:?}",
                head
            );
        }
    }

    #[test]
    fn parse_response_status_extracts_code() {
        assert_statuses(&[
            ("HTTP/1.1 200 OK\r\n", 200),
            ("HTTP/1.1 404 Not Found\r\n", 404),
            ("HTTP/1.1 502 Bad Gateway\r\n", 502),
        ]);
    }

    #[test]
    fn parse_response_status_handles_garbage() {
        // Anything that is not an `HTTP/` status line falls back to 502.
        assert_statuses(&[("", 502), ("garbage", 502), ("NOT-HTTP 200 OK", 502)]);
    }
}