use std::collections::HashMap;
use std::time::Duration;
use tokio::io::{
AsyncReadExt,
AsyncWriteExt,
};
use tokio::net::{
TcpListener,
TcpStream,
};
use tokio::sync::oneshot;
use url::Url;
/// A single HTTP request as observed by the test servers in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CapturedRequest {
    /// Method token from the request line (e.g. `GET`); empty if the request line lacked one.
    pub method: String,
    /// Request target from the request line, exactly as sent; empty if it was missing.
    pub target: String,
    /// Header fields keyed by lowercased name, with surrounding whitespace trimmed from names
    /// and values. When a name repeats, only the last value is kept.
    pub headers: HashMap<String, String>,
    /// Body bytes the server captured after the blank line that ends the header block.
    pub body: Vec<u8>,
}
/// One step of a `ResponsePlan::Chunked` response body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResponseChunk {
    /// How long the server sleeps before writing this chunk; a zero duration skips the sleep.
    pub delay: Duration,
    /// Payload bytes for this chunk.
    pub bytes: Vec<u8>,
}
/// Describes how a test server answers the request it captured.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ResponsePlan {
    /// Write the status line, `headers` and the whole `body` at once. A `Content-Length`
    /// matching `body` is appended unless `headers` already has one (case-insensitive).
    Immediate {
        status: u16,
        headers: Vec<(String, String)>,
        body: Vec<u8>,
    },
    /// Like `Immediate`, but header values are raw bytes written verbatim, so they need not
    /// be valid UTF-8.
    ImmediateRawHeaders {
        status: u16,
        headers: Vec<(String, Vec<u8>)>,
        body: Vec<u8>,
    },
    /// Sleep for `delay` before writing anything, then respond like `Immediate`.
    DelayedStart {
        delay: Duration,
        status: u16,
        headers: Vec<(String, String)>,
        body: Vec<u8>,
    },
    /// Advertise `total_length` as the `Content-Length` (unless `headers` already has one),
    /// write only `prefix` as the body, then sleep for `delay` before the server stops
    /// writing.
    PartialThenDelay {
        status: u16,
        headers: Vec<(String, String)>,
        total_length: usize,
        prefix: Vec<u8>,
        delay: Duration,
    },
    /// Send a chunked body, adding `Transfer-Encoding: chunked` unless `headers` already has
    /// that header. Each entry of `chunks` is written after its own delay. The terminating
    /// zero-length chunk is written only when `finish` is `true`.
    Chunked {
        status: u16,
        headers: Vec<(String, String)>,
        chunks: Vec<ResponseChunk>,
        finish: bool,
    },
}
/// Handle to a test server that accepts exactly one connection, captures its request and
/// replies according to a single `ResponsePlan`. Created by `spawn_one_shot_server`.
#[derive(Debug)]
pub struct OneShotServer {
    /// `http://<bound address>/` of the listener, bound on 127.0.0.1 with an OS-chosen port.
    base_url: Url,
    /// Yields the captured request once the server task has read it.
    request_rx: oneshot::Receiver<CapturedRequest>,
    /// The spawned task that serves the single connection.
    join_handle: tokio::task::JoinHandle<()>,
}
/// Handle to a test server that accepts one connection per `ResponsePlan`, pairing plans with
/// connections in accept order. Created by `spawn_multi_shot_server`.
#[derive(Debug)]
pub struct MultiShotServer {
    /// `http://<bound address>/` of the listener, bound on 127.0.0.1 with an OS-chosen port.
    base_url: Url,
    /// Yields every captured request, in accept order, after all connections were served.
    request_rx: oneshot::Receiver<Vec<CapturedRequest>>,
    /// The spawned task that accepts connections and waits for their per-connection tasks.
    join_handle: tokio::task::JoinHandle<()>,
}
impl OneShotServer {
    /// Returns the base URL (`http://<bound address>/`) that clients should send requests to.
    pub fn base_url(&self) -> Url {
        self.base_url.clone()
    }

    /// Waits for the server to capture its request and finish responding, then returns the
    /// captured request.
    ///
    /// # Panics
    ///
    /// Panics if the server task panicked. When the task died before handing over a request,
    /// its original panic payload is re-raised, so the test failure shows the real cause.
    pub async fn finish(self) -> CapturedRequest {
        let Self {
            request_rx,
            join_handle,
            ..
        } = self;
        match request_rx.await {
            Ok(request) => {
                join_handle
                    .await
                    .expect("one-shot test server task panicked");
                request
            }
            // The sender is dropped unsent when the server task exits early, for example
            // because accepting or reading panicked. Report that panic rather than a generic
            // channel error.
            Err(_) => match join_handle.await {
                Err(error) if error.is_panic() => std::panic::resume_unwind(error.into_panic()),
                Err(error) => panic!("one-shot test server task failed: {error}"),
                Ok(()) => panic!("one-shot test server dropped request sender"),
            },
        }
    }
}
impl MultiShotServer {
    /// Returns the base URL (`http://<bound address>/`) that clients should send requests to.
    pub fn base_url(&self) -> Url {
        self.base_url.clone()
    }

    /// Waits until every planned connection has been served, then returns the captured
    /// requests in the order their connections were accepted.
    ///
    /// # Panics
    ///
    /// Panics if the server task (or one of its per-connection tasks) panicked. When the task
    /// died before handing over the requests, its original panic payload is re-raised, so the
    /// test failure shows the real cause.
    pub async fn finish(self) -> Vec<CapturedRequest> {
        let Self {
            request_rx,
            join_handle,
            ..
        } = self;
        match request_rx.await {
            Ok(requests) => {
                join_handle
                    .await
                    .expect("multi-shot test server task panicked");
                requests
            }
            // The sender is dropped unsent when the server task exits early, for example
            // because an accept failed or a connection task panicked. Report that panic
            // rather than a generic channel error.
            Err(_) => match join_handle.await {
                Err(error) if error.is_panic() => std::panic::resume_unwind(error.into_panic()),
                Err(error) => panic!("multi-shot test server task failed: {error}"),
                Ok(()) => panic!("multi-shot test server dropped request sender"),
            },
        }
    }
}
pub async fn spawn_one_shot_server(plan: ResponsePlan) -> OneShotServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("failed to bind one-shot test server");
let addr = listener
.local_addr()
.expect("failed to query one-shot server local address");
let base_url = Url::parse(&format!("http://{addr}/")).expect("failed to build base URL");
let (request_tx, request_rx) = oneshot::channel::<CapturedRequest>();
let join_handle = tokio::spawn(async move {
let accept_result = listener.accept().await;
let (mut stream, _) = match accept_result {
Ok(result) => result,
Err(error) => panic!("one-shot test server failed to accept connection: {error}"),
};
let request = read_request(&mut stream)
.await
.expect("failed to read request in one-shot test server");
let _ = request_tx.send(request);
if let Err(error) = write_response(&mut stream, plan).await {
if !is_expected_client_disconnect(&error) {
panic!("failed to write response in one-shot test server: {error}");
}
}
});
OneShotServer {
base_url,
request_rx,
join_handle,
}
}
pub async fn spawn_multi_shot_server(plans: Vec<ResponsePlan>) -> MultiShotServer {
let listener = TcpListener::bind("127.0.0.1:0")
.await
.expect("failed to bind multi-shot test server");
let addr = listener
.local_addr()
.expect("failed to query multi-shot server local address");
let base_url = Url::parse(&format!("http://{addr}/")).expect("failed to build base URL");
let (request_tx, request_rx) = oneshot::channel::<Vec<CapturedRequest>>();
let join_handle = tokio::spawn(async move {
let mut handles = Vec::with_capacity(plans.len());
for (index, plan) in plans.into_iter().enumerate() {
let accept_result = listener.accept().await;
let (mut stream, _) = match accept_result {
Ok(result) => result,
Err(error) => panic!("multi-shot test server failed to accept connection: {error}"),
};
handles.push(tokio::spawn(async move {
let request = read_request(&mut stream)
.await
.expect("failed to read request in multi-shot test server");
if let Err(error) = write_response(&mut stream, plan).await {
if !is_expected_client_disconnect(&error) {
panic!("failed to write response in multi-shot test server: {error}");
}
}
(index, request)
}));
}
let mut requests = vec![None; handles.len()];
for handle in handles {
let (index, request) = handle.await.expect("multi-shot request task panicked");
requests[index] = Some(request);
}
let requests = requests
.into_iter()
.map(|request| request.expect("multi-shot request missing"))
.collect();
let _ = request_tx.send(requests);
});
MultiShotServer {
base_url,
request_rx,
join_handle,
}
}
/// Reads one HTTP/1.x request (request line, headers and body) from `stream`.
///
/// The header block is read up to the first blank line (`\r\n\r\n`).
///
/// When the request has a parseable `Content-Length`, reading continues until that many body
/// bytes have arrived, and the body is truncated to exactly that length. If the client sent
/// `Expect: 100-continue`, an interim `HTTP/1.1 100 Continue` is written first so the client
/// starts sending.
///
/// Without a usable `Content-Length`, the body is whatever bytes arrived together with the
/// header block. Chunked request bodies are not decoded.
///
/// Header names are lowercased, and names and values are trimmed. A repeated header keeps its
/// last value. Missing request-line parts become empty strings.
///
/// # Errors
///
/// - `TimedOut` if any single read waits longer than 3 seconds.
/// - `UnexpectedEof` if the peer closes before the headers, or the declared body, are complete.
/// - Any other I/O error from the socket.
async fn read_request(stream: &mut TcpStream) -> std::io::Result<CapturedRequest> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";
    let read_timeout = Duration::from_secs(3);
    let mut buffer = Vec::new();
    let mut chunk = [0_u8; 1024];

    let header_end_index = loop {
        let read_size = read_with_timeout(
            stream,
            &mut chunk,
            read_timeout,
            "timed out while waiting for request headers",
        )
        .await?;
        if read_size == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "connection closed before request headers were complete",
            ));
        }
        // Only the old tail that could start a terminator split across two reads needs
        // rescanning, which keeps header accumulation linear in the request size.
        let search_start = buffer.len().saturating_sub(HEADER_TERMINATOR.len() - 1);
        buffer.extend_from_slice(&chunk[..read_size]);
        if let Some(index) = find_subsequence(&buffer[search_start..], HEADER_TERMINATOR) {
            break search_start + index + HEADER_TERMINATOR.len();
        }
    };

    // Bytes past the header block are the beginning of the body.
    let mut body = buffer.split_off(header_end_index);
    let header_text = String::from_utf8_lossy(&buffer);
    let mut lines = header_text.split("\r\n");
    let request_line = lines.next().unwrap_or_default();
    let mut request_parts = request_line.split_whitespace();
    let method = request_parts.next().unwrap_or_default().to_string();
    let target = request_parts.next().unwrap_or_default().to_string();
    let mut headers = HashMap::new();
    for line in lines.take_while(|line| !line.is_empty()) {
        if let Some((name, value)) = line.split_once(':') {
            headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
        }
    }

    let content_length = headers
        .get("content-length")
        .and_then(|value| value.parse::<usize>().ok());
    if let Some(expected_length) = content_length {
        let expects_continue = matches!(
            headers.get("expect"),
            Some(value) if value.eq_ignore_ascii_case("100-continue")
        );
        if expects_continue && body.len() < expected_length {
            // Clients using `Expect: 100-continue` hold the body back until they see this.
            stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
            stream.flush().await?;
        }
        while body.len() < expected_length {
            let read_size = read_with_timeout(
                stream,
                &mut chunk,
                read_timeout,
                "timed out while waiting for request body",
            )
            .await?;
            if read_size == 0 {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "connection closed before request body was complete",
                ));
            }
            body.extend_from_slice(&chunk[..read_size]);
        }
        // Bytes beyond the declared length do not belong to this request's body.
        body.truncate(expected_length);
    }

    Ok(CapturedRequest {
        method,
        target,
        headers,
        body,
    })
}

/// Performs a single read from `stream` into `buffer`. Fails with `TimedOut` (carrying
/// `timeout_message`) if the read does not complete within `timeout`.
async fn read_with_timeout(
    stream: &mut TcpStream,
    buffer: &mut [u8],
    timeout: Duration,
    timeout_message: &'static str,
) -> std::io::Result<usize> {
    tokio::time::timeout(timeout, stream.read(buffer))
        .await
        .map_err(|_| std::io::Error::new(std::io::ErrorKind::TimedOut, timeout_message))?
}
/// Writes the response described by `plan` to `stream`.
///
/// Each variant adds its framing header (`Content-Length` or `Transfer-Encoding: chunked`) only
/// when the caller's headers do not already contain it; names are compared case-insensitively.
/// Some plans deliberately leave the message incomplete:
///
/// - `PartialThenDelay` advertises `total_length` but sends only `prefix`, then sleeps.
/// - `Chunked` with `finish: false` never writes the terminating chunk.
///
/// # Errors
///
/// Propagates any I/O error from writing or flushing the stream.
async fn write_response(stream: &mut TcpStream, plan: ResponsePlan) -> std::io::Result<()> {
    match plan {
        ResponsePlan::Immediate {
            status,
            headers,
            body,
        } => write_fixed_response(stream, status, headers, body).await?,
        ResponsePlan::ImmediateRawHeaders {
            status,
            mut headers,
            body,
        } => {
            if !contains_raw_header(&headers, "Content-Length") {
                headers.push((
                    "Content-Length".to_string(),
                    body.len().to_string().into_bytes(),
                ));
            }
            write_status_and_raw_headers(stream, status, &headers).await?;
            if !body.is_empty() {
                stream.write_all(&body).await?;
            }
            stream.flush().await?;
        }
        ResponsePlan::DelayedStart {
            delay,
            status,
            headers,
            body,
        } => {
            tokio::time::sleep(delay).await;
            write_fixed_response(stream, status, headers, body).await?;
        }
        ResponsePlan::PartialThenDelay {
            status,
            mut headers,
            total_length,
            prefix,
            delay,
        } => {
            if !contains_header(&headers, "Content-Length") {
                headers.push(("Content-Length".to_string(), total_length.to_string()));
            }
            write_status_and_headers(stream, status, &headers).await?;
            stream.write_all(&prefix).await?;
            stream.flush().await?;
            // Keep the connection open, with the body still incomplete, for `delay`.
            tokio::time::sleep(delay).await;
        }
        ResponsePlan::Chunked {
            status,
            mut headers,
            chunks,
            finish,
        } => {
            if !contains_header(&headers, "Transfer-Encoding") {
                headers.push(("Transfer-Encoding".to_string(), "chunked".to_string()));
            }
            write_status_and_headers(stream, status, &headers).await?;
            for chunk in chunks {
                if !chunk.delay.is_zero() {
                    tokio::time::sleep(chunk.delay).await;
                }
                // A zero-length chunk is the chunked-coding terminator (`0\r\n\r\n`). Writing
                // one here would end the body early, and would duplicate the terminator when
                // `finish` is set. An empty chunk therefore only contributes its delay.
                if chunk.bytes.is_empty() {
                    continue;
                }
                let length_line = format!("{:X}\r\n", chunk.bytes.len());
                stream.write_all(length_line.as_bytes()).await?;
                stream.write_all(&chunk.bytes).await?;
                stream.write_all(b"\r\n").await?;
                stream.flush().await?;
            }
            if finish {
                stream.write_all(b"0\r\n\r\n").await?;
                stream.flush().await?;
            }
        }
    }
    Ok(())
}
/// Writes a complete fixed-length response: the status line, then `headers`, then `body`.
/// A `Content-Length` matching `body` is appended when the caller did not supply one, and the
/// stream is flushed at the end.
///
/// # Errors
///
/// Propagates any I/O error from writing or flushing the stream.
async fn write_fixed_response(
    stream: &mut TcpStream,
    status: u16,
    mut headers: Vec<(String, String)>,
    body: Vec<u8>,
) -> std::io::Result<()> {
    let has_length = contains_header(&headers, "Content-Length");
    if !has_length {
        let declared_length = body.len().to_string();
        headers.push(("Content-Length".to_string(), declared_length));
    }
    write_status_and_headers(stream, status, &headers).await?;
    match body.as_slice() {
        [] => {}
        bytes => stream.write_all(bytes).await?,
    }
    stream.flush().await
}
/// Writes the status line, every `name: value` header line and the blank separator line. All of
/// it goes out in a single `write_all`. Does not flush.
///
/// # Errors
///
/// Propagates any I/O error from writing to the stream.
async fn write_status_and_headers(
    stream: &mut TcpStream,
    status: u16,
    headers: &[(String, String)],
) -> std::io::Result<()> {
    let header_lines: String = headers
        .iter()
        .map(|(name, value)| format!("{name}: {value}\r\n"))
        .collect();
    let head = format!(
        "HTTP/1.1 {status} {reason}\r\n{header_lines}\r\n",
        reason = reason_phrase(status),
    );
    stream.write_all(head.as_bytes()).await
}
/// Reports whether `headers` already contains a field named `name`. Names are compared
/// ASCII-case-insensitively.
fn contains_header(headers: &[(String, String)], name: &str) -> bool {
    for (existing, _) in headers {
        if existing.eq_ignore_ascii_case(name) {
            return true;
        }
    }
    false
}
/// Reports whether the raw-valued `headers` already contain a field named `name`. Names are
/// compared ASCII-case-insensitively, and values are ignored.
fn contains_raw_header(headers: &[(String, Vec<u8>)], name: &str) -> bool {
    headers
        .iter()
        .map(|(existing, _)| existing.as_str())
        .any(|existing| existing.eq_ignore_ascii_case(name))
}
/// Writes the status line, every header (values copied verbatim as raw bytes) and the blank
/// separator line. Does not flush.
///
/// The whole head is assembled in one buffer and sent with a single `write_all`, as
/// `write_status_and_headers` does. Emitting each fragment separately on an unbuffered socket
/// costs one syscall per fragment. It can also trigger Nagle/delayed-ACK stalls that distort
/// timing-sensitive tests.
///
/// # Errors
///
/// Propagates any I/O error from writing to the stream.
async fn write_status_and_raw_headers(
    stream: &mut TcpStream,
    status: u16,
    headers: &[(String, Vec<u8>)],
) -> std::io::Result<()> {
    let mut head = format!("HTTP/1.1 {} {}\r\n", status, reason_phrase(status)).into_bytes();
    for (name, value) in headers {
        head.extend_from_slice(name.as_bytes());
        head.extend_from_slice(b": ");
        head.extend_from_slice(value);
        head.extend_from_slice(b"\r\n");
    }
    head.extend_from_slice(b"\r\n");
    stream.write_all(&head).await
}
/// Returns `true` for I/O errors meaning the client went away mid-response: a broken pipe, an
/// aborted or reset connection, or a socket that is no longer connected. The test servers ignore
/// these instead of panicking.
fn is_expected_client_disconnect(error: &std::io::Error) -> bool {
    use std::io::ErrorKind::{BrokenPipe, ConnectionAborted, ConnectionReset, NotConnected};
    [BrokenPipe, ConnectionAborted, ConnectionReset, NotConnected].contains(&error.kind())
}
/// Returns the index of the first occurrence of `needle` in `haystack`, or `None` if it does
/// not occur. An empty `needle` matches at index 0, mirroring `str::find("")`.
fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
    // `slice::windows` panics on a zero window size, so the empty needle is handled here.
    if needle.is_empty() {
        return Some(0);
    }
    haystack
        .windows(needle.len())
        .position(|window| window == needle)
}
/// Returns the standard reason phrase for `status`, following the RFC 9110 wording. Codes not
/// in the table map to `"Unknown"`; HTTP clients ignore the phrase, so it is informational only.
fn reason_phrase(status: u16) -> &'static str {
    match status {
        100 => "Continue",
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        413 => "Content Too Large",
        415 => "Unsupported Media Type",
        416 => "Range Not Satisfiable",
        422 => "Unprocessable Content",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "Unknown",
    }
}