use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
use crate::error::ProxyError;
/// Largest request or response body, in bytes, that the readers in this
/// module will buffer (64 MiB); a larger declared `Content-Length` is refused.
pub const MAX_BODY_LEN: usize = 64 * 1024 * 1024;
/// Longest single start line or header line, in bytes, accepted by the
/// readers in this module. The count includes the line terminator.
pub const MAX_HEADER_LINE_LEN: usize = 8 * 1024;
/// Total byte budget shared by the start line and every header line of one
/// message, including the blank line that ends the header section.
pub const MAX_HEADER_BYTES: usize = 64 * 1024;
/// Maximum number of header fields accepted in one message.
pub const MAX_HEADER_COUNT: usize = 200;
/// Reads one `\n`-terminated line from `reader` and appends it, terminator
/// included, to `out`, refusing to buffer more than `min(max_line, remaining)`
/// bytes.
///
/// Returns the number of bytes consumed from `reader`. `0` means the reader
/// was already at EOF. If EOF arrives mid-line, the partial line collected so
/// far is appended and its length returned.
///
/// The size check runs against each buffered chunk *before* it is consumed,
/// so an oversized line is refused without ever being held in full.
///
/// # Errors
///
/// - `ProxyError::Config` if the line would exceed the cap.
/// - `ProxyError::Config` if the collected bytes are not valid UTF-8. In that
///   case the bytes have been consumed but `out` is left untouched.
/// - Any I/O error from `reader`, converted with `?`.
pub(crate) async fn read_line_capped<R>(
    reader: &mut BufReader<R>,
    out: &mut String,
    max_line: usize,
    remaining: usize,
) -> Result<usize, ProxyError>
where
    R: tokio::io::AsyncRead + Unpin,
{
    let cap = std::cmp::min(max_line, remaining);
    let mut collected: Vec<u8> = Vec::new();
    let mut saw_newline = false;

    while !saw_newline {
        let buffered = reader.fill_buf().await?;
        if buffered.is_empty() {
            // EOF: return whatever (possibly partial) line was gathered.
            break;
        }

        // Take everything up to and including the first '\n', or the whole
        // buffer when this chunk holds no line terminator yet.
        let take = match buffered.iter().position(|&byte| byte == b'\n') {
            Some(newline_at) => {
                saw_newline = true;
                newline_at + 1
            }
            None => buffered.len(),
        };

        if collected.len() + take > cap {
            return Err(ProxyError::Config(format!(
                "HTTP header line exceeds maximum {cap} bytes; refusing (fail-closed)"
            )));
        }

        collected.extend_from_slice(&buffered[..take]);
        reader.consume(take);
    }

    let consumed = collected.len();
    match String::from_utf8(collected) {
        Ok(text) => {
            out.push_str(&text);
            Ok(consumed)
        }
        Err(_) => Err(ProxyError::Config("invalid UTF-8 in HTTP header line".into())),
    }
}
/// A parsed HTTP/1.x request: request line, header fields in received order,
/// and body bytes.
#[derive(Debug, Clone)]
pub struct HttpRequest {
    /// Method token from the request line, e.g. `POST`.
    pub method: String,
    /// Request target exactly as received, e.g. `/v1/chat/completions`.
    pub target: String,
    /// Protocol version token, e.g. `HTTP/1.1`.
    pub version: String,
    /// Header `(name, value)` pairs in received order. Names keep their
    /// original case and duplicates are preserved.
    pub headers: Vec<(String, String)>,
    /// Body bytes.
    pub body: Vec<u8>,
}

impl HttpRequest {
    /// Returns the value of the first header whose name equals `name`,
    /// compared ASCII-case-insensitively, or `None` if no header matches.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find_map(|(key, value)| key.eq_ignore_ascii_case(name).then(|| value.as_str()))
    }
}
pub async fn read_http_request<R>(reader: &mut BufReader<R>) -> Result<Option<HttpRequest>, ProxyError>
where
R: tokio::io::AsyncRead + Unpin,
{
let mut head_budget = MAX_HEADER_BYTES;
let mut request_line = String::new();
let n = read_line_capped(reader, &mut request_line, MAX_HEADER_LINE_LEN, head_budget).await?;
if n == 0 {
return Ok(None);
}
head_budget -= n;
let trimmed = request_line.trim_end_matches(['\r', '\n']);
let parts: Vec<&str> = trimmed.splitn(3, ' ').collect();
if parts.len() != 3 {
return Err(ProxyError::Config(format!("malformed HTTP request line: {trimmed:?}")));
}
let method = parts[0].to_string();
let target = parts[1].to_string();
let version = parts[2].to_string();
let mut headers: Vec<(String, String)> = Vec::new();
loop {
let mut line = String::new();
let n = read_line_capped(reader, &mut line, MAX_HEADER_LINE_LEN, head_budget).await?;
if n == 0 {
return Err(ProxyError::Config("unexpected EOF reading headers".into()));
}
head_budget -= n;
let trimmed = line.trim_end_matches(['\r', '\n']);
if trimmed.is_empty() {
break;
}
if headers.len() >= MAX_HEADER_COUNT {
return Err(ProxyError::Config(format!(
"HTTP request exceeds maximum {MAX_HEADER_COUNT} header lines; refusing (fail-closed)"
)));
}
if let Some((k, v)) = trimmed.split_once(':') {
headers.push((k.trim().to_string(), v.trim().to_string()));
} else {
return Err(ProxyError::Config(format!("malformed header line: {trimmed:?}")));
}
}
if headers
.iter()
.any(|(k, v)| k.eq_ignore_ascii_case("transfer-encoding") && v.to_ascii_lowercase().contains("chunked"))
{
return Err(ProxyError::Config(
"transfer-encoding: chunked request bodies are not inspectable; refusing (fail-closed)".into(),
));
}
let content_length: usize = headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-length"))
.and_then(|(_, v)| v.parse().ok())
.unwrap_or(0);
if content_length > MAX_BODY_LEN {
return Err(ProxyError::Config(format!(
"request Content-Length {content_length} exceeds maximum {MAX_BODY_LEN}; refusing (fail-closed)"
)));
}
let mut body = vec![0u8; content_length];
if content_length > 0 {
reader.read_exact(&mut body).await?;
}
Ok(Some(HttpRequest {
method,
target,
version,
headers,
body,
}))
}
/// Serializes `req` with `new_body` as its body. Any client `Authorization`
/// and `x-api-key` headers are forwarded unchanged.
///
/// Shorthand for [`serialize_http_request_with_auth`] without an injected
/// credential. See that function for the exact wire format.
pub fn serialize_http_request(req: &HttpRequest, new_body: &[u8]) -> Vec<u8> {
    let no_credential: Option<&[u8]> = None;
    serialize_http_request_with_auth(req, new_body, no_credential)
}

/// Serializes `req` into HTTP/1.1 wire bytes with `new_body` as the body.
///
/// The request line is written as `METHOD SP TARGET SP VERSION CRLF`. It is
/// followed by the original headers in their original order and spelling,
/// with these exceptions:
/// - `Content-Length`, `Transfer-Encoding` and `Connection` (any case) are
///   always dropped. They are replaced, after all other headers, by
///   `Connection: close` and a `Content-Length` matching `new_body`, so the
///   upstream frames exactly the new body.
/// - When `injected_auth` is `Some`, the client's `Authorization` and
///   `x-api-key` headers are dropped. A single `Authorization: <value>` header
///   carrying the injected bytes is written instead.
///
/// NOTE(review): `injected_auth` is copied verbatim without any CR/LF check.
/// Callers must make sure it cannot terminate the header line early.
pub fn serialize_http_request_with_auth(req: &HttpRequest, new_body: &[u8], injected_auth: Option<&[u8]>) -> Vec<u8> {
    // Framing/connection headers the proxy always writes itself.
    const REGENERATED: [&str; 3] = ["content-length", "transfer-encoding", "connection"];
    // Client-supplied credentials, dropped only when a real key is injected.
    const CLIENT_CREDENTIALS: [&str; 2] = ["authorization", "x-api-key"];

    let is_forwarded = |name: &str| -> bool {
        let regenerated = REGENERATED.iter().any(|h| name.eq_ignore_ascii_case(h));
        let replaced = injected_auth.is_some()
            && CLIENT_CREDENTIALS.iter().any(|h| name.eq_ignore_ascii_case(h));
        !regenerated && !replaced
    };

    let mut wire = Vec::with_capacity(req.body.len() + new_body.len() + 256);

    let request_line = [req.method.as_str(), req.target.as_str(), req.version.as_str()].join(" ");
    wire.extend_from_slice(request_line.as_bytes());
    wire.extend_from_slice(b"\r\n");

    for (name, value) in req.headers.iter().filter(|(name, _)| is_forwarded(name.as_str())) {
        let pieces: [&[u8]; 4] = [name.as_bytes(), b": ", value.as_bytes(), b"\r\n"];
        for piece in &pieces {
            wire.extend_from_slice(piece);
        }
    }

    if let Some(credential) = injected_auth {
        let pieces: [&[u8]; 3] = [b"Authorization: ", credential, b"\r\n"];
        for piece in &pieces {
            wire.extend_from_slice(piece);
        }
    }

    wire.extend_from_slice(b"Connection: close\r\n");
    wire.extend_from_slice(format!("Content-Length: {}\r\n\r\n", new_body.len()).as_bytes());
    wire.extend_from_slice(new_body);
    wire
}
/// A parsed HTTP/1.x response: status line, header fields in received order,
/// and body bytes.
#[derive(Debug, Clone)]
pub struct HttpResponse {
/// Protocol version token from the status line, e.g. `HTTP/1.1`.
pub version: String,
/// Status-code token exactly as received, e.g. `200`. It is stored as text,
/// not as a parsed number.
pub status_code: String,
/// Reason phrase; empty when the status line carries none.
pub reason: String,
/// Header `(name, value)` pairs in received order. Names keep their original
/// case and duplicates are preserved.
pub headers: Vec<(String, String)>,
/// Body bytes.
pub body: Vec<u8>,
}
/// Reads one HTTP/1.x response from `reader`: the status line, the header
/// section, and the body.
///
/// Returns `Ok(None)` when the reader is already at EOF before the first byte
/// of a status line.
///
/// Body framing:
/// - 1xx, 204 and 304 responses never carry a body (RFC 9112 §6.3). No body
///   is read even when a `Content-Length` is present. A 304 commonly repeats
///   the length of the cached representation.
/// - Any other response reads exactly `Content-Length` bytes, or nothing when
///   the header is absent.
///
/// NOTE(review): responses to HEAD requests have no body either, but the
/// request method is not available here. A response framed only by chunked
/// encoding or by connection close (no `Content-Length`) comes back with an
/// empty body. Confirm callers never hand such responses to this function.
///
/// # Errors
///
/// Returns `ProxyError::Config` when any of the following holds:
/// - The status line has fewer than two space-separated parts.
/// - A line exceeds [`MAX_HEADER_LINE_LEN`], the head exceeds
///   [`MAX_HEADER_BYTES`], or a line is not valid UTF-8.
/// - More than [`MAX_HEADER_COUNT`] header fields are sent.
/// - A line contains a bare CR or a NUL byte, or a header line has no `:`.
/// - The stream ends before the blank line that terminates the headers.
/// - The response may carry a body and its `Content-Length` is not a plain
///   decimal number, is repeated with differing values, or exceeds
///   [`MAX_BODY_LEN`].
///
/// I/O errors from `reader` are converted into `ProxyError` with `?`. This
/// includes EOF before the declared body has fully arrived.
pub async fn read_http_response<R>(reader: &mut BufReader<R>) -> Result<Option<HttpResponse>, ProxyError>
where
    R: tokio::io::AsyncRead + Unpin,
{
    /// Refuses a status or header line that still contains CR or NUL once its
    /// terminator has been stripped.
    ///
    /// Header values are later re-emitted verbatim. A bare CR could be read
    /// as a line break by the client and smuggle extra header lines.
    fn reject_forbidden_bytes(line: &str) -> Result<(), ProxyError> {
        if line.contains(['\r', '\0']) {
            return Err(ProxyError::Config(
                "HTTP response status line or header contains a bare CR or NUL byte; refusing (fail-closed)"
                    .into(),
            ));
        }
        Ok(())
    }

    /// Returns the declared body length, or 0 when no `Content-Length` header
    /// is present.
    ///
    /// Every occurrence must be a plain decimal number and all occurrences
    /// must agree. A value too large for `usize` is reported as exceeding the
    /// cap instead of being silently treated as 0.
    fn declared_content_length(headers: &[(String, String)]) -> Result<usize, ProxyError> {
        let mut declared: Option<usize> = None;
        for (_, value) in headers
            .iter()
            .filter(|(k, _)| k.eq_ignore_ascii_case("content-length"))
        {
            if value.is_empty() || !value.bytes().all(|b| b.is_ascii_digit()) {
                return Err(ProxyError::Config(format!(
                    "invalid response Content-Length {value:?}; refusing (fail-closed)"
                )));
            }
            // The value is all ASCII digits, so parsing can only fail on overflow.
            let len: usize = value.parse().map_err(|_| {
                ProxyError::Config(format!(
                    "response Content-Length {value} exceeds maximum {MAX_BODY_LEN}; refusing (fail-closed)"
                ))
            })?;
            match declared {
                Some(prev) if prev != len => {
                    return Err(ProxyError::Config(format!(
                        "conflicting response Content-Length values {prev} and {len}; refusing (fail-closed)"
                    )));
                }
                _ => declared = Some(len),
            }
        }
        Ok(declared.unwrap_or(0))
    }

    let mut head_budget = MAX_HEADER_BYTES;
    let mut status_line = String::new();
    let n = read_line_capped(reader, &mut status_line, MAX_HEADER_LINE_LEN, head_budget).await?;
    if n == 0 {
        return Ok(None);
    }
    // Cannot underflow: read_line_capped never returns more than `head_budget`.
    head_budget -= n;
    let trimmed = status_line.trim_end_matches(['\r', '\n']);
    reject_forbidden_bytes(trimmed)?;
    let parts: Vec<&str> = trimmed.splitn(3, ' ').collect();
    if parts.len() < 2 {
        return Err(ProxyError::Config(format!("malformed HTTP status line: {trimmed:?}")));
    }
    let version = parts[0].to_string();
    let status_code = parts[1].to_string();
    let reason = parts.get(2).copied().unwrap_or("").to_string();

    let mut headers: Vec<(String, String)> = Vec::new();
    loop {
        let mut line = String::new();
        let n = read_line_capped(reader, &mut line, MAX_HEADER_LINE_LEN, head_budget).await?;
        if n == 0 {
            return Err(ProxyError::Config("unexpected EOF reading response headers".into()));
        }
        head_budget -= n;
        let trimmed = line.trim_end_matches(['\r', '\n']);
        if trimmed.is_empty() {
            break;
        }
        reject_forbidden_bytes(trimmed)?;
        if headers.len() >= MAX_HEADER_COUNT {
            return Err(ProxyError::Config(format!(
                "HTTP response exceeds maximum {MAX_HEADER_COUNT} header lines; refusing (fail-closed)"
            )));
        }
        if let Some((k, v)) = trimmed.split_once(':') {
            headers.push((k.trim().to_string(), v.trim().to_string()));
        } else {
            return Err(ProxyError::Config(format!(
                "malformed response header line: {trimmed:?}"
            )));
        }
    }

    // RFC 9112 §6.3: these statuses end at the blank line after the headers
    // regardless of Content-Length. Trying to read a declared length here
    // would fail or stall on a response that is already complete.
    let has_no_body = status_code.starts_with('1') || status_code == "204" || status_code == "304";
    let content_length = if has_no_body {
        0
    } else {
        declared_content_length(&headers)?
    };
    if content_length > MAX_BODY_LEN {
        return Err(ProxyError::Config(format!(
            "response Content-Length {content_length} exceeds maximum {MAX_BODY_LEN}; refusing (fail-closed)"
        )));
    }

    // Grow the buffer as bytes actually arrive instead of pre-allocating
    // `content_length` zeroed bytes for a body that may never come.
    let mut body = Vec::new();
    if content_length > 0 {
        // content_length <= MAX_BODY_LEN here, so widening to u64 is lossless.
        let received = (&mut *reader)
            .take(content_length as u64)
            .read_to_end(&mut body)
            .await?;
        if received < content_length {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "connection closed before the full response body was received",
            )
            .into());
        }
    }

    Ok(Some(HttpResponse {
        version,
        status_code,
        reason,
        headers,
        body,
    }))
}
/// Serializes `resp` for the client with `new_body` as its body.
///
/// The status line is `VERSION SP STATUS CRLF`, or `VERSION SP STATUS SP
/// REASON CRLF` when the reason phrase is non-empty. Headers keep their
/// original order and spelling, except that every `Content-Length` and
/// `Transfer-Encoding` header (any case) is dropped. A single
/// `Content-Length` matching `new_body` is appended after the remaining
/// headers, so the client frames exactly the replacement body.
pub fn serialize_http_response(resp: &HttpResponse, new_body: &[u8]) -> Vec<u8> {
    let status_line = if resp.reason.is_empty() {
        format!("{} {}\r\n", resp.version, resp.status_code)
    } else {
        format!("{} {} {}\r\n", resp.version, resp.status_code, resp.reason)
    };

    let mut wire = Vec::with_capacity(resp.body.len() + new_body.len() + 256);
    wire.extend_from_slice(status_line.as_bytes());

    let forwarded = resp.headers.iter().filter(|(name, _)| {
        !(name.eq_ignore_ascii_case("content-length") || name.eq_ignore_ascii_case("transfer-encoding"))
    });
    for (name, value) in forwarded {
        let header_line = [name.as_str(), ": ", value.as_str(), "\r\n"].concat();
        wire.extend_from_slice(header_line.as_bytes());
    }

    wire.extend_from_slice(format!("Content-Length: {}\r\n\r\n", new_body.len()).as_bytes());
    wire.extend_from_slice(new_body);
    wire
}
// Unit tests for the readers and serializers in this module. Inputs are fed
// through a tokio `BufReader` over an in-memory `Cursor`, so no sockets are
// involved.
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
/// Wraps `bytes` in a `BufReader<Cursor<_>>` that the async readers can consume.
fn make_reader(bytes: &[u8]) -> BufReader<Cursor<Vec<u8>>> {
BufReader::new(Cursor::new(bytes.to_vec()))
}
// --- Request parsing: happy paths ---
#[tokio::test]
async fn parses_post_with_content_length_body() {
// The 13-byte body is the 11-byte JSON plus a trailing CRLF that belongs to
// the body, not to the header section.
let raw = b"POST /v1/chat/completions HTTP/1.1\r\n\
Host: api.openai.com\r\n\
Content-Type: application/json\r\n\
Content-Length: 13\r\n\
\r\n\
{\"hello\":1}\r\n";
let mut reader = make_reader(raw);
let req = read_http_request(&mut reader).await.unwrap().expect("request present");
assert_eq!(req.method, "POST");
assert_eq!(req.target, "/v1/chat/completions");
assert_eq!(req.version, "HTTP/1.1");
assert_eq!(req.header("host"), Some("api.openai.com"));
assert_eq!(req.body.len(), 13);
assert_eq!(&req.body, b"{\"hello\":1}\r\n");
}
#[tokio::test]
async fn parses_get_with_no_body() {
let raw = b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n";
let mut reader = make_reader(raw);
let req = read_http_request(&mut reader).await.unwrap().expect("request present");
assert_eq!(req.method, "GET");
assert!(req.body.is_empty());
}
#[tokio::test]
async fn returns_none_on_clean_eof() {
let mut reader = make_reader(b"");
let req = read_http_request(&mut reader).await.unwrap();
assert!(req.is_none(), "EOF before request line must return None");
}
#[tokio::test]
async fn header_lookup_is_case_insensitive() {
let raw = b"GET / HTTP/1.1\r\nX-Custom: v\r\n\r\n";
let mut reader = make_reader(raw);
let req = read_http_request(&mut reader).await.unwrap().unwrap();
assert_eq!(req.header("x-custom"), Some("v"));
assert_eq!(req.header("X-CUSTOM"), Some("v"));
}
// --- Request serialization: body replacement, framing/connection headers,
// credential injection ---
#[tokio::test]
async fn serialize_rewrites_content_length_for_smaller_body() {
let raw = b"POST /v1/chat/completions HTTP/1.1\r\n\
Host: api.openai.com\r\n\
Content-Type: application/json\r\n\
Content-Length: 20\r\n\
\r\n\
01234567890123456789";
let mut reader = make_reader(raw);
let req = read_http_request(&mut reader).await.unwrap().unwrap();
let new_body = b"short";
let wire = serialize_http_request(&req, new_body);
let text = std::str::from_utf8(&wire).unwrap();
assert!(text.starts_with("POST /v1/chat/completions HTTP/1.1\r\n"));
assert!(text.contains("Host: api.openai.com\r\n"));
assert!(text.contains("Content-Type: application/json\r\n"));
assert_eq!(text.matches("Content-Length: 5\r\n").count(), 1);
assert!(!text.contains("Content-Length: 20"));
assert!(text.ends_with("\r\n\r\nshort"));
}
#[test]
fn serialize_drops_transfer_encoding_header() {
let req = HttpRequest {
method: "POST".into(),
target: "/".into(),
version: "HTTP/1.1".into(),
headers: vec![
("Host".into(), "x.example.com".into()),
("Transfer-Encoding".into(), "chunked".into()),
],
body: Vec::new(),
};
let wire = serialize_http_request(&req, b"hi");
let text = std::str::from_utf8(&wire).unwrap();
assert!(
!text.to_ascii_lowercase().contains("transfer-encoding"),
"serialized request must drop Transfer-Encoding when replacing body, got: {text}",
);
assert!(text.contains("Content-Length: 2\r\n"));
}
#[test]
fn serialize_forces_connection_close_and_strips_inbound_connection() {
let req = HttpRequest {
method: "POST".into(),
target: "/v1/x".into(),
version: "HTTP/1.1".into(),
headers: vec![
("Host".into(), "api.openai.com".into()),
("Connection".into(), "keep-alive".into()),
],
body: Vec::new(),
};
let wire = serialize_http_request(&req, b"hi");
let text = std::str::from_utf8(&wire).unwrap();
assert_eq!(
text.to_ascii_lowercase().matches("connection:").count(),
1,
"exactly one Connection header expected: {text}"
);
assert!(
text.contains("Connection: close\r\n"),
"must force Connection: close: {text}"
);
assert!(
!text.to_ascii_lowercase().contains("keep-alive"),
"inbound keep-alive Connection header must be dropped: {text}"
);
}
#[tokio::test]
async fn inject_auth_strips_agent_header_and_appends_real_key() {
let raw = b"POST /v1/chat/completions HTTP/1.1\r\n\
Host: api.openai.com\r\n\
Authorization: Bearer agent-bogus-token\r\n\
x-api-key: agent-bogus-key\r\n\
Content-Length: 2\r\n\
\r\n\
hi";
let mut reader = make_reader(raw);
let req = read_http_request(&mut reader).await.unwrap().unwrap();
let wire = serialize_http_request_with_auth(&req, &req.body, Some(b"Bearer sk-REAL-PROVIDER-KEY"));
let text = std::str::from_utf8(&wire).unwrap();
assert!(
!text.contains("agent-bogus-token"),
"agent Authorization must be stripped: {text}"
);
assert!(
!text.contains("agent-bogus-key"),
"agent x-api-key must be stripped: {text}"
);
assert_eq!(
text.matches("Authorization: Bearer sk-REAL-PROVIDER-KEY\r\n").count(),
1,
"injected Authorization must appear exactly once: {text}"
);
assert!(
text.contains("Host: api.openai.com\r\n"),
"non-credential headers preserved"
);
}
#[tokio::test]
async fn inject_auth_none_forwards_agent_header_verbatim() {
let raw = b"POST / HTTP/1.1\r\n\
Host: api.openai.com\r\n\
Authorization: Bearer agent-token\r\n\
Content-Length: 2\r\n\
\r\n\
hi";
let mut reader = make_reader(raw);
let req = read_http_request(&mut reader).await.unwrap().unwrap();
let wire = serialize_http_request_with_auth(&req, &req.body, None);
let text = std::str::from_utf8(&wire).unwrap();
assert!(
text.contains("Authorization: Bearer agent-token\r\n"),
"agent header forwarded: {text}"
);
}
// --- Fail-closed rejection of malformed, unframeable or oversized input ---
#[tokio::test]
async fn malformed_request_line_is_config_error() {
let mut reader = make_reader(b"GET-ONLY\r\n\r\n");
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
matches!(err, ProxyError::Config(_)),
"expected Config error, got {err:?}"
);
}
#[tokio::test]
async fn malformed_header_line_is_config_error() {
let mut reader = make_reader(b"GET / HTTP/1.1\r\nNoColonHere\r\n\r\n");
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
matches!(err, ProxyError::Config(_)),
"expected Config error, got {err:?}"
);
}
#[tokio::test]
async fn rejects_chunked_transfer_encoding_request() {
let raw = b"POST /v1/chat/completions HTTP/1.1\r\n\
Host: api.openai.com\r\n\
Transfer-Encoding: chunked\r\n\
\r\n\
5\r\nhello\r\n0\r\n\r\n";
let mut reader = make_reader(raw);
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
matches!(&err, ProxyError::Config(msg) if msg.contains("chunked")),
"expected fail-closed chunked rejection, got: {err:?}"
);
}
#[tokio::test]
async fn rejects_request_content_length_over_cap_without_allocating() {
let raw = b"POST /v1/chat/completions HTTP/1.1\r\n\
Host: api.openai.com\r\n\
Content-Length: 2000000000\r\n\
\r\n";
let mut reader = make_reader(raw);
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
matches!(&err, ProxyError::Config(msg) if msg.contains("exceeds maximum")),
"expected over-cap rejection, got: {err:?}"
);
}
#[tokio::test]
async fn accepts_request_content_length_at_cap_boundary() {
// A Content-Length exactly at MAX_BODY_LEN passes the cap check. The read
// then fails only because no body bytes follow the headers.
let raw = format!("POST / HTTP/1.1\r\nHost: x\r\nContent-Length: {MAX_BODY_LEN}\r\n\r\n");
let mut reader = make_reader(raw.as_bytes());
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
!matches!(&err, ProxyError::Config(msg) if msg.contains("exceeds maximum")),
"cap boundary must not be rejected by the cap check, got: {err:?}"
);
}
#[tokio::test]
async fn rejects_response_content_length_over_cap_without_allocating() {
let raw = b"HTTP/1.1 200 OK\r\n\
Content-Length: 2000000000\r\n\
\r\n";
let mut reader = make_reader(raw);
let err = read_http_response(&mut reader).await.unwrap_err();
assert!(
matches!(&err, ProxyError::Config(msg) if msg.contains("exceeds maximum")),
"expected over-cap rejection, got: {err:?}"
);
}
#[tokio::test]
async fn unexpected_eof_reading_request_headers_is_error() {
let mut reader = make_reader(b"GET / HTTP/1.1\r\n");
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
matches!(err, ProxyError::Config(_)),
"expected Config error, got {err:?}"
);
}
// --- Response parsing ---
#[tokio::test]
async fn parses_response_with_content_length_body() {
let raw = b"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Content-Length: 11\r\n\
\r\n\
{\"ok\":true}";
let mut reader = make_reader(raw);
let resp = read_http_response(&mut reader)
.await
.unwrap()
.expect("response present");
assert_eq!(resp.version, "HTTP/1.1");
assert_eq!(resp.status_code, "200");
assert_eq!(resp.reason, "OK");
assert_eq!(&resp.body, b"{\"ok\":true}");
}
#[tokio::test]
async fn response_returns_none_on_clean_eof() {
let mut reader = make_reader(b"");
assert!(read_http_response(&mut reader).await.unwrap().is_none());
}
#[tokio::test]
async fn response_with_no_reason_phrase_parses() {
let raw = b"HTTP/1.1 204\r\nContent-Length: 0\r\n\r\n";
let mut reader = make_reader(raw);
let resp = read_http_response(&mut reader).await.unwrap().unwrap();
assert_eq!(resp.status_code, "204");
assert_eq!(resp.reason, "");
assert!(resp.body.is_empty());
}
#[tokio::test]
async fn malformed_status_line_is_config_error() {
let mut reader = make_reader(b"GARBAGE\r\n\r\n");
let err = read_http_response(&mut reader).await.unwrap_err();
assert!(
matches!(err, ProxyError::Config(_)),
"expected Config error, got {err:?}"
);
}
#[tokio::test]
async fn malformed_response_header_line_is_config_error() {
let mut reader = make_reader(b"HTTP/1.1 200 OK\r\nNoColonHeader\r\n\r\n");
let err = read_http_response(&mut reader).await.unwrap_err();
assert!(
matches!(err, ProxyError::Config(_)),
"expected Config error, got {err:?}"
);
}
#[tokio::test]
async fn unexpected_eof_reading_response_headers_is_error() {
let mut reader = make_reader(b"HTTP/1.1 200 OK\r\n");
let err = read_http_response(&mut reader).await.unwrap_err();
assert!(
matches!(err, ProxyError::Config(_)),
"expected Config error, got {err:?}"
);
}
// --- Response serialization ---
#[tokio::test]
async fn serialize_response_rewrites_content_length_and_drops_transfer_encoding() {
let raw = b"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Transfer-Encoding: chunked\r\n\
Content-Length: 3\r\n\
\r\n\
old";
let mut reader = make_reader(raw);
let resp = read_http_response(&mut reader).await.unwrap().unwrap();
let new_body = b"redacted";
let wire = serialize_http_response(&resp, new_body);
let text = std::str::from_utf8(&wire).unwrap();
assert!(text.starts_with("HTTP/1.1 200 OK\r\n"));
assert!(text.contains("Content-Type: application/json\r\n"));
assert!(
!text.to_ascii_lowercase().contains("transfer-encoding"),
"Transfer-Encoding must be stripped: {text}"
);
assert_eq!(text.matches("Content-Length:").count(), 1);
assert_eq!(text.matches("Content-Length: 8\r\n").count(), 1);
assert!(text.ends_with("\r\n\r\nredacted"));
}
#[tokio::test]
async fn serialize_response_omits_empty_reason_phrase() {
let raw = b"HTTP/1.1 204\r\nContent-Length: 0\r\n\r\n";
let mut reader = make_reader(raw);
let resp = read_http_response(&mut reader).await.unwrap().unwrap();
let wire = serialize_http_response(&resp, b"");
let text = std::str::from_utf8(&wire).unwrap();
assert!(
text.starts_with("HTTP/1.1 204\r\n"),
"no trailing reason space: {text:?}"
);
}
// --- Header line-length and header-count caps ---
#[tokio::test]
async fn rejects_oversized_request_header_line_before_oom() {
let big_value = "a".repeat(MAX_HEADER_LINE_LEN + 1);
let raw = format!("GET / HTTP/1.1\r\nX-Big: {big_value}\r\n\r\n");
let mut reader = make_reader(raw.as_bytes());
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
matches!(&err, ProxyError::Config(msg) if msg.contains("exceeds maximum")),
"expected over-cap header rejection, got: {err:?}"
);
}
#[tokio::test]
async fn rejects_request_with_too_many_headers() {
let mut raw = String::from("GET / HTTP/1.1\r\n");
// Inclusive range: sends MAX_HEADER_COUNT + 1 header lines, one over the cap.
for i in 0..=MAX_HEADER_COUNT {
raw.push_str(&format!("X-H{i}: v\r\n"));
}
raw.push_str("\r\n");
let mut reader = make_reader(raw.as_bytes());
let err = read_http_request(&mut reader).await.unwrap_err();
assert!(
matches!(&err, ProxyError::Config(msg) if msg.contains("header lines")),
"expected too-many-headers rejection, got: {err:?}"
);
}
#[tokio::test]
async fn header_within_caps_still_parses() {
let raw = b"POST / HTTP/1.1\r\nHost: api.openai.com\r\nContent-Length: 2\r\n\r\nhi";
let mut reader = make_reader(raw);
let req = read_http_request(&mut reader).await.unwrap().expect("request present");
assert_eq!(req.header("host"), Some("api.openai.com"));
assert_eq!(&req.body, b"hi");
}
#[tokio::test]
async fn rejects_oversized_response_header_line_before_oom() {
let big_value = "a".repeat(MAX_HEADER_LINE_LEN + 1);
let raw = format!("HTTP/1.1 200 OK\r\nX-Big: {big_value}\r\n\r\n");
let mut reader = make_reader(raw.as_bytes());
let err = read_http_response(&mut reader).await.unwrap_err();
assert!(
matches!(&err, ProxyError::Config(msg) if msg.contains("exceeds maximum")),
"expected over-cap header rejection, got: {err:?}"
);
}
}