use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::{Duration, Instant};
const READ_TIMEOUT: Duration = Duration::from_secs(5);
const REQUEST_READ_DEADLINE: Duration = Duration::from_secs(5);
/// Spawns a background HTTP/1.1 stub on an ephemeral `127.0.0.1` port that serves
/// `responses` in order, one canned response per accepted connection.
///
/// For each connection the request is read first (headers plus any `Content-Length` body),
/// then the canned response is written and the socket is shut down. After the last canned
/// response, the listener keeps accepting for roughly 250 ms. It answers any further
/// connection with an empty `503 Service Unavailable`.
///
/// Returns the bound address and a counter of connections that received a canned response.
/// Drain-phase 503 replies are not counted.
///
/// # Panics
/// Panics if the ephemeral port cannot be bound or its local address cannot be read.
pub fn spawn_oneshot_http_responder(responses: Vec<&'static str>) -> (SocketAddr, Arc<AtomicU32>) {
    /// Accept failures tied to a single pending connection (interrupted syscall, peer
    /// aborted before the accept completed) rather than to the listener itself.
    fn is_transient(err: &std::io::Error) -> bool {
        matches!(
            err.kind(),
            std::io::ErrorKind::Interrupted | std::io::ErrorKind::ConnectionAborted
        )
    }

    let listener = TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
    let addr = listener.local_addr().expect("local_addr");
    let counter = Arc::new(AtomicU32::new(0));
    let counter_inner = Arc::clone(&counter);
    std::thread::spawn(move || {
        for resp in &responses {
            let stream = loop {
                match listener.accept() {
                    Ok((stream, _)) => break stream,
                    // Retry for the same canned response; giving up here would drop the
                    // listener and leave every remaining response unserved.
                    Err(e) if is_transient(&e) => continue,
                    Err(_) => return,
                }
            };
            counter_inner.fetch_add(1, Ordering::SeqCst);
            serve_one(stream, resp);
        }

        // Drain phase: poll the listener without blocking for a short window and answer
        // any extra connection with an empty 503.
        let _ = listener.set_nonblocking(true);
        let drain_deadline = Instant::now() + Duration::from_millis(250);
        while Instant::now() < drain_deadline {
            match listener.accept() {
                Ok((stream, _)) => {
                    // Put the accepted socket back in blocking mode so the read timeout that
                    // serve_one sets governs its reads.
                    let _ = stream.set_nonblocking(false);
                    serve_one(
                        stream,
                        "HTTP/1.1 503 Service Unavailable\r\nContent-Length: 0\r\n\r\n",
                    );
                }
                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                    std::thread::sleep(Duration::from_millis(10));
                }
                Err(e) if is_transient(&e) => {}
                Err(_) => break,
            }
        }
    });
    (addr, counter)
}
/// Spawns a background HTTP/1.1 stub on an ephemeral `127.0.0.1` port that serves exactly
/// one connection and records that connection's raw request.
///
/// The request (headers plus any `Content-Length` body) is read and stored in the returned
/// mutex, lossily decoded as UTF-8. Only then is `response` written and the socket shut
/// down. The mutex therefore holds an empty string until the request has been read.
///
/// # Panics
/// Panics if the ephemeral port cannot be bound or its local address cannot be read.
pub fn spawn_request_capturing_responder(
    response: &'static str,
) -> (SocketAddr, Arc<std::sync::Mutex<String>>) {
    /// Accept failures tied to a single pending connection (interrupted syscall, peer
    /// aborted before the accept completed) rather than to the listener itself.
    fn is_transient(err: &std::io::Error) -> bool {
        matches!(
            err.kind(),
            std::io::ErrorKind::Interrupted | std::io::ErrorKind::ConnectionAborted
        )
    }

    let listener = TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
    let addr = listener.local_addr().expect("local_addr");
    let captured = Arc::new(std::sync::Mutex::new(String::new()));
    let captured_inner = Arc::clone(&captured);
    std::thread::spawn(move || {
        let mut stream = loop {
            match listener.accept() {
                Ok((stream, _)) => break stream,
                // A transient failure is about one aborted connection; keep waiting for the
                // real client instead of silently ending the responder.
                Err(e) if is_transient(&e) => continue,
                Err(_) => return,
            }
        };
        let _ = stream.set_read_timeout(Some(READ_TIMEOUT));
        let request_bytes = consume_request_capturing(&mut stream);
        // `into_owned` reuses the buffer when lossy decoding already had to allocate.
        let request_text = String::from_utf8_lossy(&request_bytes).into_owned();
        // Stored before the response is written, so a client that has received the
        // response will find the capture populated.
        *captured_inner.lock().unwrap() = request_text;
        let _ = stream.write_all(response.as_bytes());
        let _ = stream.flush();
        let _ = stream.shutdown(std::net::Shutdown::Both);
    });
    (addr, captured)
}
/// Handles one accepted connection.
///
/// Sets the read timeout to `READ_TIMEOUT`, reads the incoming request, writes `resp`
/// verbatim, flushes, and shuts down both directions. Every step is best-effort: a failure
/// is neither reported nor stops the later steps.
fn serve_one(stream: TcpStream, resp: &str) {
    let mut conn = stream;
    _ = conn.set_read_timeout(Some(READ_TIMEOUT));
    consume_request(&mut conn);
    _ = conn.write_all(resp.as_bytes());
    _ = conn.flush();
    _ = conn.shutdown(std::net::Shutdown::Both);
}
/// Reads and discards one HTTP/1.x request from `stream`.
///
/// Reads until the blank line that ends the header block. If the headers carry a parseable
/// `Content-Length`, it then reads and discards body bytes until that many have been seen,
/// counting any that arrived together with the headers.
///
/// Returns early, without reporting anything, in these cases:
/// - EOF;
/// - a non-retryable read error, including a read-timeout expiry;
/// - more than 1 MiB of header data has accumulated;
/// - `REQUEST_READ_DEADLINE` has elapsed.
///
/// `ErrorKind::Interrupted` reads are retried. The deadline is checked only between reads,
/// so a single blocking read may overrun it by up to the stream's read timeout. Bodies
/// without `Content-Length` (for example chunked ones) are not drained.
fn consume_request(stream: &mut TcpStream) {
    const MAX_HEADER_BYTES: usize = 1 << 20;
    let deadline = Instant::now() + REQUEST_READ_DEADLINE;
    let mut accum: Vec<u8> = Vec::with_capacity(8 * 1024);
    let mut chunk = [0u8; 8 * 1024];

    let header_end = loop {
        if Instant::now() >= deadline {
            return;
        }
        match stream.read(&mut chunk) {
            Ok(0) => return,
            Ok(n) => {
                accum.extend_from_slice(&chunk[..n]);
                if let Some(pos) = find_double_crlf(&accum) {
                    // Skip past the "\r\n\r\n" terminator itself.
                    break pos + 4;
                }
                if accum.len() > MAX_HEADER_BYTES {
                    return;
                }
            }
            // An interrupted read transferred no data and should simply be retried.
            Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(_) => return,
        }
    };

    let content_length = parse_content_length(&accum[..header_end]);
    let Some(total_body) = content_length else {
        return;
    };
    // Body bytes that arrived in the same reads as the headers already count.
    let mut remaining = total_body.saturating_sub(accum.len() - header_end);
    while remaining > 0 {
        if Instant::now() >= deadline {
            return;
        }
        // Never read past the declared body, so pipelined data is left on the socket.
        let want = remaining.min(chunk.len());
        match stream.read(&mut chunk[..want]) {
            Ok(0) => return,
            Ok(n) => remaining -= n,
            Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(_) => return,
        }
    }
}
/// Reads one HTTP/1.x request from `stream` and returns every byte that was read.
///
/// Reads until the blank line that ends the header block. If the headers carry a parseable
/// `Content-Length`, it then keeps reading until that many body bytes have been received,
/// counting any that arrived together with the headers.
///
/// Stops early in these cases, returning whatever was read so far:
/// - EOF;
/// - a non-retryable read error, including a read-timeout expiry;
/// - more than 1 MiB of header data has accumulated;
/// - `REQUEST_READ_DEADLINE` has elapsed.
///
/// `ErrorKind::Interrupted` reads are retried. The deadline is checked only between reads,
/// so a single blocking read may overrun it by up to the stream's read timeout.
fn consume_request_capturing(stream: &mut TcpStream) -> Vec<u8> {
    const MAX_HEADER_BYTES: usize = 1 << 20;
    let deadline = Instant::now() + REQUEST_READ_DEADLINE;
    let mut accum: Vec<u8> = Vec::with_capacity(8 * 1024);
    let mut chunk = [0u8; 8 * 1024];

    let header_end = loop {
        if Instant::now() >= deadline {
            return accum;
        }
        match stream.read(&mut chunk) {
            Ok(0) => return accum,
            Ok(n) => {
                accum.extend_from_slice(&chunk[..n]);
                if let Some(pos) = find_double_crlf(&accum) {
                    // Skip past the "\r\n\r\n" terminator itself.
                    break pos + 4;
                }
                if accum.len() > MAX_HEADER_BYTES {
                    return accum;
                }
            }
            // An interrupted read transferred no data and should simply be retried.
            Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(_) => return accum,
        }
    };

    let content_length = parse_content_length(&accum[..header_end]);
    let Some(total_body) = content_length else {
        return accum;
    };
    // Body bytes that arrived in the same reads as the headers already count.
    let mut remaining = total_body.saturating_sub(accum.len() - header_end);
    while remaining > 0 {
        if Instant::now() >= deadline {
            break;
        }
        // Never read past the declared body, so pipelined data is left on the socket.
        let want = remaining.min(chunk.len());
        match stream.read(&mut chunk[..want]) {
            Ok(0) => break,
            Ok(n) => {
                accum.extend_from_slice(&chunk[..n]);
                remaining -= n;
            }
            Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(_) => break,
        }
    }
    accum
}
/// Returns the offset of the first `\r\n\r\n` in `buf` (the end of an HTTP header block),
/// or `None` if the sequence does not occur.
fn find_double_crlf(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    // A buffer shorter than the terminator cannot contain it.
    let last_start = buf.len().checked_sub(TERMINATOR.len())?;
    (0..=last_start).find(|&start| buf[start..].starts_with(TERMINATOR))
}
/// Returns the value of the first `Content-Length` header found in `header_block`.
///
/// Lines are split on `\r\n`. The header name is matched case-insensitively, and whitespace
/// around both name and value is trimmed. Returns `None` when no `Content-Length` header
/// exists, or when the first one does not parse as a `usize`.
///
/// The block is decoded lossily. HTTP allows non-UTF-8 bytes (obs-text) in field values,
/// and such a byte in an unrelated header must not hide a valid `Content-Length`.
fn parse_content_length(header_block: &[u8]) -> Option<usize> {
    let text = String::from_utf8_lossy(header_block);
    text.split("\r\n")
        .filter_map(|line| line.split_once(':'))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("content-length"))
        .and_then(|(_, value)| value.trim().parse::<usize>().ok())
}
// Self-checks for this module: unit tests of the header-parsing helpers, plus two
// end-to-end tests that talk to the stub responders over loopback TCP.
#[cfg(test)]
mod self_tests {
use super::*;
// The terminator starts right after "GET / HTTP/1.1\r\nHost: x" (23 bytes).
#[test]
fn find_double_crlf_locates_header_terminator() {
let buf = b"GET / HTTP/1.1\r\nHost: x\r\n\r\nbody";
assert_eq!(find_double_crlf(buf), Some(23));
}
// A header block missing its final blank line has no terminator.
#[test]
fn find_double_crlf_returns_none_when_absent() {
assert_eq!(find_double_crlf(b"GET / HTTP/1.1\r\nHost: x\r\n"), None);
}
// Header names are matched regardless of ASCII case.
#[test]
fn parse_content_length_case_insensitive() {
let hdr = b"PUT / HTTP/1.1\r\nHost: x\r\nContent-Length: 42\r\n\r\n";
assert_eq!(parse_content_length(hdr), Some(42));
let hdr = b"PUT / HTTP/1.1\r\nHost: x\r\ncontent-length: 7\r\n\r\n";
assert_eq!(parse_content_length(hdr), Some(7));
}
#[test]
fn parse_content_length_missing_returns_none() {
let hdr = b"GET / HTTP/1.1\r\nHost: x\r\n\r\n";
assert_eq!(parse_content_length(hdr), None);
}
// A non-numeric value yields None rather than an error or a default.
#[test]
fn parse_content_length_unparseable_returns_none() {
let hdr = b"PUT / HTTP/1.1\r\nContent-Length: chunked\r\n\r\n";
assert_eq!(parse_content_length(hdr), None);
}
// End-to-end: the 32 KiB body is larger than the responder's 8 KiB read chunk, so the
// responder has to loop over several reads to consume it. The client must then receive
// the complete canned response.
#[test]
fn responder_consumes_full_body_before_responding() {
use std::io::{Read, Write};
use std::net::TcpStream;
let canned =
"HTTP/1.1 201 Created\r\nContent-Length: 2\r\nContent-Type: text/plain\r\n\r\nok";
let (addr, calls) = spawn_oneshot_http_responder(vec![canned]);
let body = vec![b'x'; 32 * 1024];
let body_len = body.len();
let request = format!(
"PUT /api/v2/package HTTP/1.1\r\nHost: 127.0.0.1\r\nContent-Length: {body_len}\r\n\r\n"
);
let mut stream = TcpStream::connect(addr).expect("connect");
// Bounds how long the test can hang if the responder never answers.
stream
.set_read_timeout(Some(Duration::from_secs(5)))
.expect("read timeout");
stream.write_all(request.as_bytes()).expect("write headers");
stream.write_all(&body).expect("write body");
stream.flush().expect("flush");
// read_to_string returns once the responder shuts the connection down after writing.
let mut response = String::new();
stream
.read_to_string(&mut response)
.expect("read full response");
assert!(
response.starts_with("HTTP/1.1 201 Created"),
"unexpected response: {response:?}"
);
assert!(
response.ends_with("ok"),
"unexpected response: {response:?}"
);
// The counter only counts connections that were given a canned response.
assert_eq!(calls.load(Ordering::SeqCst), 1, "exactly one accept");
}
// End-to-end: the capturing responder must record the raw request, including headers
// such as Authorization.
#[test]
fn capturing_responder_records_request_headers() {
use std::io::Write;
use std::net::TcpStream;
let canned = "HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
let (addr, captured) = spawn_request_capturing_responder(canned);
let request = "GET /search/issues HTTP/1.1\r\n\
Host: 127.0.0.1\r\n\
Authorization: Bearer secret-token\r\n\
Content-Length: 0\r\n\r\n";
let mut stream = TcpStream::connect(addr).expect("connect");
stream
.set_read_timeout(Some(Duration::from_secs(5)))
.expect("read timeout");
stream.write_all(request.as_bytes()).expect("write");
stream.flush().expect("flush");
// Presumably gives the responder time to read the request before the client closes
// its side; the capture check below does not depend on the response being read.
std::thread::sleep(Duration::from_millis(50));
let _ = stream.shutdown(std::net::Shutdown::Both);
// The capture is written by the responder thread, so poll (up to 2 s) until it is
// non-empty.
let deadline = Instant::now() + Duration::from_secs(2);
let captured_str = loop {
let s = captured.lock().unwrap().clone();
if !s.is_empty() || Instant::now() >= deadline {
break s;
}
std::thread::sleep(Duration::from_millis(10));
};
// Compare in lowercase so header-name and value casing do not matter.
let lower = captured_str.to_ascii_lowercase();
assert!(
lower.contains("authorization: bearer secret-token"),
"captured request missing Authorization: {captured_str:?}"
);
}
}