use std::collections::BTreeMap;
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use serde_json::Value as JsonValue;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Notify;
use tokio::task::JoinHandle;
#[allow(unused_imports)]
pub(crate) use crate::triggers::test_util::clock::MockClock;
/// One HTTP request as read off the wire by the fake server.
#[derive(Clone, Debug)]
pub(crate) struct FakeHttpRequest {
    /// Method token from the request line, e.g. `GET` or `POST`.
    pub(crate) method: String,
    /// Request target from the request line (the path plus any query string).
    pub(crate) path: String,
    /// Headers keyed by lowercased name with trimmed values; when a header is
    /// repeated only its last value is kept.
    pub(crate) headers: BTreeMap<String, String>,
    /// Request body, decoded lossily as UTF-8; empty when no
    /// `content-length` body was sent.
    pub(crate) body: String,
}
impl FakeHttpRequest {
    /// Parses the captured body as JSON.
    ///
    /// Returns `None` when the body is empty or is not valid JSON.
    #[allow(dead_code)] // only some tests inspect request payloads
    pub(crate) fn body_json(&self) -> Option<JsonValue> {
        if self.body.is_empty() {
            return None;
        }
        serde_json::from_str(&self.body).ok()
    }
}
/// A canned reply the fake server writes back for one request.
///
/// When `disconnect` is set nothing is written at all, so the client sees the
/// connection close without a response; the other fields are then ignored.
#[derive(Clone, Debug)]
pub(crate) struct FakeHttpResponse {
    /// Status code placed on the status line.
    pub(crate) status: u16,
    /// Extra headers, written in order after the server-generated
    /// `content-length` and `connection` headers.
    pub(crate) headers: Vec<(String, String)>,
    /// Raw body bytes.
    pub(crate) body: Vec<u8>,
    /// Drop the connection instead of answering.
    pub(crate) disconnect: bool,
}
impl FakeHttpResponse {
    /// A `200` response whose body is `body` serialized as JSON.
    pub(crate) fn ok_json(body: &JsonValue) -> Self {
        Self::status_json(200, body)
    }

    /// A response with `status` whose body is `body` serialized as JSON,
    /// labelled `content-type: application/json`.
    pub(crate) fn status_json(status: u16, body: &JsonValue) -> Self {
        Self::typed_body(status, "application/json", body.to_string().into_bytes())
    }

    /// A response with `status` and a plain-text body, labelled
    /// `content-type: text/plain`.
    pub(crate) fn text(status: u16, body: impl Into<String>) -> Self {
        Self::typed_body(status, "text/plain", body.into().into_bytes())
    }

    /// Replaces *all* headers (including the default `content-type`) with
    /// `headers`, preserving their iteration order.
    #[allow(dead_code)] // only some tests customise headers
    pub(crate) fn with_headers<I, K, V>(self, headers: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        Self {
            headers: headers
                .into_iter()
                .map(|(name, value)| (name.into(), value.into()))
                .collect(),
            ..self
        }
    }

    /// Replaces the body bytes, leaving status and headers untouched.
    #[allow(dead_code)] // only some tests override the body
    pub(crate) fn with_body(self, body: impl Into<Vec<u8>>) -> Self {
        Self {
            body: body.into(),
            ..self
        }
    }

    /// A response that makes the server close the connection without replying.
    #[allow(dead_code)] // only failure-injection tests use this
    pub(crate) fn disconnect() -> Self {
        Self {
            status: 0,
            headers: vec![],
            body: vec![],
            disconnect: true,
        }
    }

    /// Builds a non-disconnecting response with a single `content-type` header.
    fn typed_body(status: u16, content_type: &str, body: Vec<u8>) -> Self {
        Self {
            status,
            headers: vec![("content-type".into(), content_type.into())],
            body,
            disconnect: false,
        }
    }
}
/// An in-process HTTP/1.1 server for tests.
///
/// Binds an ephemeral port on `127.0.0.1` and serves connections one after
/// another on a single spawned task. Each connection carries one request
/// (responses are sent with `connection: close`). Every parsed request is
/// recorded, and the reply comes from the caller-supplied handler.
pub(crate) struct FakeHttpServer {
    addr: SocketAddr,
    base_url: String,
    requests: Arc<Mutex<Vec<FakeHttpRequest>>>,
    shutdown: Arc<Notify>,
    handle: Option<JoinHandle<()>>,
}
impl FakeHttpServer {
    /// Request budget used by [`FakeHttpServer::start`].
    const DEFAULT_CAPACITY: usize = 1024;

    /// Starts a server that answers up to `DEFAULT_CAPACITY` requests with
    /// `handler`; see [`FakeHttpServer::start_with_capacity`].
    #[allow(dead_code)] // not every test module needs the default budget
    pub(crate) async fn start<F>(label: &'static str, handler: F) -> Self
    where
        F: FnMut(usize, SocketAddr, &FakeHttpRequest) -> FakeHttpResponse + Send + 'static,
    {
        Self::start_with_capacity(label, Self::DEFAULT_CAPACITY, handler).await
    }

    /// Starts a server that answers at most `expected_requests` requests.
    ///
    /// `handler` is called once per successfully parsed request. It receives
    /// the zero-based request index, the server's own address and the
    /// captured request. Whatever it returns is written back to the client.
    ///
    /// A connection that closes or errors before a complete request arrives
    /// is skipped. It does not consume an index or count toward
    /// `expected_requests`. Once the budget is spent the accept loop ends and
    /// the listener is dropped, so no further connections are accepted.
    ///
    /// `label` prefixes diagnostics printed by the background task.
    ///
    /// # Panics
    ///
    /// Panics if binding `127.0.0.1:0` fails. The background task panics,
    /// which ends request handling, if `accept` fails or `handler` panics.
    pub(crate) async fn start_with_capacity<F>(
        label: &'static str,
        expected_requests: usize,
        mut handler: F,
    ) -> Self
    where
        F: FnMut(usize, SocketAddr, &FakeHttpRequest) -> FakeHttpResponse + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind fake http server");
        let addr = listener.local_addr().expect("fake http server addr");
        let base_url = format!("http://{addr}");
        let requests: Arc<Mutex<Vec<FakeHttpRequest>>> = Arc::new(Mutex::new(Vec::new()));
        let shutdown = Arc::new(Notify::new());
        let requests_task = requests.clone();
        let shutdown_task = shutdown.clone();
        let handle = tokio::spawn(async move {
            // Count requests actually handed to `handler`, not accepted
            // connections. Otherwise a connection that yields no parseable
            // request would eat a slot: a scripted server would stop before
            // its script is exhausted, and handler indices would have gaps.
            let mut served = 0;
            while served < expected_requests {
                let mut stream = tokio::select! {
                    _ = shutdown_task.notified() => return,
                    accept = listener.accept() => match accept {
                        Ok((stream, _)) => stream,
                        Err(error) => panic!("{label}: accept failed: {error}"),
                    },
                };
                let request = match read_request(&mut stream).await {
                    Some(request) => request,
                    None => continue,
                };
                // Record before invoking the handler so the request is
                // visible even if the handler panics.
                requests_task
                    .lock()
                    .expect("fake http requests poisoned")
                    .push(request.clone());
                let response = handler(served, addr, &request);
                served += 1;
                if let Err(error) = write_response(&mut stream, response).await {
                    eprintln!("{label}: write failed: {error}");
                }
            }
        });
        Self {
            addr,
            base_url,
            requests,
            shutdown,
            handle: Some(handle),
        }
    }

    /// Starts a server that replays `script` in order, one response per
    /// request. The request budget equals `script.len()`.
    ///
    /// # Panics
    ///
    /// The server task panics if the handler is invoked more often than
    /// there are scripted responses. That is a guard only: the budget already
    /// equals the script length.
    #[allow(dead_code)] // only some tests use pre-scripted responses
    pub(crate) async fn scripted(label: &'static str, script: Vec<FakeHttpResponse>) -> Self {
        let capacity = script.len();
        let mut iter = script.into_iter();
        Self::start_with_capacity(label, capacity, move |_, _, _| {
            iter.next()
                .expect("fake http: pre-scripted responses exhausted")
        })
        .await
    }

    /// Base URL of the server, `http://127.0.0.1:<port>` with no trailing slash.
    pub(crate) fn base_url(&self) -> &str {
        &self.base_url
    }

    /// Socket address the server is listening on.
    #[allow(dead_code)] // only some tests need the raw address
    pub(crate) fn addr(&self) -> SocketAddr {
        self.addr
    }

    /// Snapshot of every request parsed so far, in arrival order.
    ///
    /// # Panics
    ///
    /// Panics if the request log mutex is poisoned.
    #[allow(dead_code)] // only some tests inspect the full request log
    pub(crate) fn requests(&self) -> Vec<FakeHttpRequest> {
        self.requests
            .lock()
            .expect("fake http requests poisoned")
            .clone()
    }

    /// Returns the first captured request matching `predicate`.
    ///
    /// # Panics
    ///
    /// Panics if no captured request matches.
    #[allow(dead_code)] // only some tests assert on specific requests
    pub(crate) fn assert_received(
        &self,
        predicate: impl Fn(&FakeHttpRequest) -> bool,
    ) -> FakeHttpRequest {
        let snapshot = self.requests();
        snapshot
            .into_iter()
            .find(predicate)
            .expect("fake http: no captured request matched the predicate")
    }
}
impl Drop for FakeHttpServer {
    /// Stops the background accept loop.
    ///
    /// `notify_waiters` wakes the loop only if it is currently parked in its
    /// `select!`, because no permit is stored for later waiters. The abort
    /// covers the case where the task is suspended elsewhere, e.g. mid-read.
    fn drop(&mut self) {
        let task = self.handle.take();
        self.shutdown.notify_waiters();
        if let Some(task) = task {
            task.abort();
        }
    }
}
/// Reads a single HTTP/1.x request from `stream`.
///
/// Parsing proceeds as follows:
/// - Bytes are buffered until the blank line that ends the header section
///   (`\r\n\r\n`).
/// - The request line supplies `method` and `path`.
/// - Each later line containing `:` becomes a header, keyed by its lowercased
///   name with a trimmed value. A repeated header keeps its last value.
/// - Exactly `content-length` bytes are then read as the body and decoded
///   lossily as UTF-8. The length is zero if the header is missing or not a
///   valid `usize`.
///
/// Only `content-length` framing is understood; a chunked body is left
/// unread.
///
/// Returns `None` in either of these cases:
/// - the peer closes the connection, or a read fails, before the request is
///   complete;
/// - the request line lacks a method or path.
async fn read_request(stream: &mut TcpStream) -> Option<FakeHttpRequest> {
    /// Appends the next chunk read from `stream` to `raw`; `None` on EOF or
    /// on an I/O error.
    async fn pull(stream: &mut TcpStream, raw: &mut Vec<u8>, scratch: &mut [u8]) -> Option<()> {
        match stream.read(scratch).await {
            Ok(0) | Err(_) => None,
            Ok(read) => {
                raw.extend_from_slice(&scratch[..read]);
                Some(())
            }
        }
    }

    let mut raw: Vec<u8> = Vec::new();
    let mut scratch = [0u8; 4096];

    // Offset just past the terminating blank line, i.e. where the body starts.
    let head_len = loop {
        pull(stream, &mut raw, &mut scratch).await?;
        if let Some(start) = find_double_crlf(&raw) {
            break start + 4;
        }
    };

    let head = String::from_utf8_lossy(&raw[..head_len]).into_owned();
    let mut lines = head.split("\r\n").filter(|line| !line.is_empty());
    let mut request_line = lines.next()?.split_whitespace();
    let method = request_line.next()?.to_owned();
    let path = request_line.next()?.to_owned();
    let headers = lines
        .filter_map(|line| line.split_once(':'))
        .fold(BTreeMap::new(), |mut map, (name, value)| {
            map.insert(name.to_ascii_lowercase(), value.trim().to_owned());
            map
        });

    let body_len = headers
        .get("content-length")
        .and_then(|value| value.parse::<usize>().ok())
        .unwrap_or(0);
    let body_end = head_len + body_len;
    while raw.len() < body_end {
        pull(stream, &mut raw, &mut scratch).await?;
    }
    let body = String::from_utf8_lossy(&raw[head_len..body_end]).into_owned();

    Some(FakeHttpRequest {
        method,
        path,
        headers,
        body,
    })
}
/// Returns the index of the first `\r\n\r\n` in `buffer`, or `None` if the
/// sequence does not occur.
fn find_double_crlf(buffer: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    // Only offsets with at least four bytes remaining can start a match.
    let last_start = buffer.len().saturating_sub(TERMINATOR.len() - 1);
    (0..last_start).find(|&start| buffer[start..].starts_with(TERMINATOR))
}
/// Writes `response` to `stream` as an HTTP/1.1 response.
///
/// The output is, in order:
/// - the status line;
/// - a `content-length` matching the body;
/// - `connection: close`;
/// - `response.headers` in order;
/// - a blank line and the body.
///
/// A response marked `disconnect` writes nothing, so the client observes the
/// connection closing without a reply once the caller drops the stream.
///
/// # Errors
///
/// Returns any I/O error from writing or flushing `stream`.
async fn write_response(stream: &mut TcpStream, response: FakeHttpResponse) -> std::io::Result<()> {
    if response.disconnect {
        return Ok(());
    }
    let extra_headers: String = response
        .headers
        .iter()
        .map(|(name, value)| format!("{name}: {value}\r\n"))
        .collect();
    let head = format!(
        "HTTP/1.1 {status} {reason}\r\ncontent-length: {len}\r\nconnection: close\r\n{extra_headers}\r\n",
        status = response.status,
        reason = status_reason(response.status),
        len = response.body.len(),
    );
    stream.write_all(head.as_bytes()).await?;
    stream.write_all(&response.body).await?;
    stream.flush().await
}
/// Returns the reason phrase placed on the status line for `status`.
///
/// Codes missing from the table get the neutral placeholder `"Unknown"`
/// rather than a phrase that belongs to another code (previously every
/// unlisted code was labelled `"OK"`). RFC 9112 says clients SHOULD ignore
/// the reason phrase, so this only affects readability of captured traffic.
fn status_reason(status: u16) -> &'static str {
    match status {
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        301 => "Moved Permanently",
        302 => "Found",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        413 => "Payload Too Large",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Entity",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        _ => "Unknown",
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Performs a GET against `url`, returning `(status code, body text)`.
    async fn http_get(url: &str) -> (u16, String) {
        let reply = reqwest::get(url).await.expect("get");
        let code = reply.status().as_u16();
        (code, reply.text().await.expect("body"))
    }

    /// One request reaches the handler, is answered, and is recorded.
    #[tokio::test]
    async fn handler_serves_one_request_and_captures_it() {
        let server =
            FakeHttpServer::start_with_capacity("fake-test", 1, |_index, _addr, req| {
                assert_eq!(req.method, "GET");
                FakeHttpResponse::ok_json(&json!({"ok": true, "path": req.path}))
            })
            .await;

        let (status, body) = http_get(&format!("{}/probe", server.base_url())).await;
        assert_eq!(status, 200);
        assert!(body.contains("\"path\":\"/probe\""));

        let recorded = server.requests();
        assert_eq!(recorded.len(), 1);
        let captured = server.assert_received(|r| r.path == "/probe");
        assert_eq!(captured.method, "GET");
    }

    /// Pre-scripted responses are served in order, one per request.
    #[tokio::test]
    async fn scripted_responses_replay_in_order() {
        let script = vec![
            FakeHttpResponse::ok_json(&json!({"index": 0})),
            FakeHttpResponse::status_json(429, &json!({"index": 1})),
        ];
        let server = FakeHttpServer::scripted("scripted", script).await;
        let url = format!("{}/", server.base_url());

        let first = http_get(&url).await;
        let second = http_get(&url).await;

        assert_eq!(first.0, 200);
        assert!(first.1.contains("\"index\":0"));
        assert_eq!(second.0, 429);
        assert!(second.1.contains("\"index\":1"));
    }

    /// A JSON POST body can be decoded with `body_json` inside the handler.
    #[tokio::test]
    async fn body_json_decodes_request_payload() {
        let server = FakeHttpServer::start_with_capacity("body-json", 1, |_, _, req| {
            let payload = req.body_json().expect("json body");
            assert_eq!(payload["query"], json!("ping"));
            FakeHttpResponse::ok_json(&json!({ "echo": payload }))
        })
        .await;

        let url = format!("{}/echo", server.base_url());
        let client = reqwest::Client::new();
        let response = client
            .post(url.as_str())
            .json(&json!({"query": "ping"}))
            .send()
            .await
            .expect("post");
        assert_eq!(response.status().as_u16(), 200);

        let text = response.text().await.expect("text");
        assert!(text.contains("\"echo\":{\"query\":\"ping\"}"));
    }
}