use std::{collections::HashMap, io, net::SocketAddr};
use async_trait::async_trait;
use axum::{
body::Bytes,
extract::Query,
http::{HeaderMap, Method, StatusCode, Uri},
response::{Html, IntoResponse},
Json, Router, Server,
};
use serde_json::Value;
use tokio::{
spawn,
sync::{mpsc, oneshot},
task::JoinHandle,
};
use tracing::{error, warn};
/// An HTTP request as recorded by an [`HttpServer`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpRequest {
    /// Raw request body bytes.
    pub body: Vec<u8>,
    /// Header values keyed by header name.
    pub headers: HashMap<String, String>,
    /// Request method, e.g. `"GET"`.
    pub method: String,
    /// Path component of the request URI (without the query string).
    pub path: String,
    /// Query parameters; each key maps to all of its values in the order received.
    pub query: HashMap<String, Vec<String>>,
}
/// The response a [`DefaultHttpServer`] sends back for every request it receives.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HttpResponse {
    /// `200 OK` with an empty body.
    Empty,
    /// `200 OK` with an HTML body.
    Html(String),
    /// `200 OK` with a JSON body.
    Json(Value),
    /// `200 OK` with a plain-text body.
    Text(String),
}
/// A running HTTP server whose received requests can be consumed one at a time.
#[async_trait]
pub trait HttpServer: Sync + Send {
    /// Waits for the next request received by the server, if any.
    async fn next(&mut self) -> Option<HttpRequest>;

    /// Stops the server, consuming it.
    async fn stop(self);
}
/// [`HttpServer`] implementation backed by an axum server running on a spawned tokio task.
///
/// Dropping this value without calling [`HttpServer::stop`] also drops `stop_tx`, which
/// lets the server's shutdown future resolve; nothing then waits for the task to finish.
pub struct DefaultHttpServer {
/// Receives the requests forwarded by the request handler.
req_rx: mpsc::Receiver<HttpRequest>,
/// Handle of the task driving the server future.
server: JoinHandle<()>,
/// Sending on (or dropping) this sender starts the graceful shutdown.
stop_tx: oneshot::Sender<()>,
}
impl DefaultHttpServer {
pub async fn start(addr: &SocketAddr) -> io::Result<Self> {
Self::with_response(addr, HttpResponse::Empty).await
}
pub async fn with_response(addr: &SocketAddr, resp: HttpResponse) -> io::Result<Self> {
let (stop_tx, stop_rx) = oneshot::channel();
let (req_tx, req_rx) = mpsc::channel(1);
let app = Router::new().fallback(
move |method: Method,
uri: Uri,
Query(query): Query<Vec<(String, String)>>,
headers: HeaderMap,
body: Bytes| async move {
let mut req_headers = HashMap::new();
for (name, val) in headers {
let name = if let Some(name) = &name {
name.as_str()
} else {
warn!("request contains header with no name");
continue;
};
let val = match val.to_str() {
Ok(val) => val,
Err(err) => {
warn!(details = %err, header = name, "failed to decode header value");
continue;
}
};
req_headers.insert(name.into(), val.into());
}
let query = query.into_iter().fold(
HashMap::<String, Vec<String>>::new(),
|mut query, (key, val)| {
query.entry(key).or_default().push(val);
query
},
);
let req = HttpRequest {
body: body.to_vec(),
headers: req_headers,
method: method.to_string(),
path: uri.path().into(),
query,
};
req_tx.send(req).await.ok();
match resp {
HttpResponse::Empty => StatusCode::OK.into_response(),
HttpResponse::Html(html) => (StatusCode::OK, Html(html)).into_response(),
HttpResponse::Json(json) => (StatusCode::OK, Json(json)).into_response(),
HttpResponse::Text(text) => (StatusCode::OK, text).into_response(),
}
},
);
let server = Server::bind(addr)
.serve(app.into_make_service())
.with_graceful_shutdown(async {
stop_rx.await.ok();
});
let server = spawn(async {
if let Err(err) = server.await {
error!(details = %err, "failed to start server");
}
});
Ok(Self {
req_rx,
server,
stop_tx,
})
}
}
#[async_trait]
impl HttpServer for DefaultHttpServer {
    /// Waits for the next request forwarded by the server.
    ///
    /// Returns `None` once the request channel is closed and all buffered requests have
    /// been consumed.
    async fn next(&mut self) -> Option<HttpRequest> {
        self.req_rx.recv().await
    }

    /// Gracefully shuts the server down and waits for its task to finish.
    ///
    /// Requests that were received but never consumed through [`HttpServer::next`] are
    /// discarded.
    async fn stop(self) {
        // Drop the receiver first. A handler blocked on the full request channel would
        // otherwise never complete, and graceful shutdown waits for in-flight requests,
        // so awaiting the server task below would hang forever.
        drop(self.req_rx);
        // A send error only means the server task has already exited.
        self.stop_tx.send(()).ok();
        if let Err(err) = self.server.await {
            error!(details = %err, "failed to stop server");
        }
    }
}
// Mock of the `HttpServer` trait generated by `mockall`, for testing code that depends on
// an `HttpServer`. Only compiled when the `mock` feature is enabled.
#[cfg(feature = "mock")]
mockall::mock! {
pub HttpServer {}
#[async_trait]
impl HttpServer for HttpServer {
async fn next(&mut self) -> Option<HttpRequest>;
async fn stop(self);
}
}
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use reqwest::{Client, Response};

    use super::*;

    mod default_http_server {
        use super::*;

        /// Starts a server on `port` that answers with `resp` and sends it one request.
        ///
        /// Checks that the server recorded exactly that request, stops the server, checks
        /// that it no longer accepts requests, and returns the response the client got.
        async fn run(port: u16, resp: HttpResponse) -> Response {
            let expected = HttpRequest {
                body: "abc".to_string().into_bytes(),
                headers: HashMap::from_iter([
                    ("accept".into(), "*/*".into()),
                    ("content-length".into(), "3".into()),
                    ("host".into(), format!("localhost:{port}")),
                ]),
                method: "GET".into(),
                path: "/a/b".into(),
                query: HashMap::from_iter([("foo".into(), vec!["bar1".into(), "bar2".into()])]),
            };
            let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
            // No startup delay is needed: `with_response` binds the listener before it
            // returns, so incoming connections queue until the server task accepts them.
            let mut server = DefaultHttpServer::with_response(&addr, resp)
                .await
                .expect("failed to start server");
            let client = Client::new();
            let query: Vec<(String, String)> = expected
                .query
                .clone()
                .into_iter()
                .flat_map(|(key, val)| val.into_iter().map(move |val| (key.clone(), val)))
                .collect();
            let resp = client
                .get(format!("http://localhost:{port}{}", expected.path))
                .query(&query)
                .body(expected.body.clone())
                .send()
                .await
                .expect("failed to send request");
            let status = resp.status();
            if status != reqwest::StatusCode::OK {
                let body = resp.text().await.expect("failed to read response body");
                panic!("request failed with status {status}: {body}");
            }
            let req = server.next().await.expect("failed to receive request");
            assert_eq!(req, expected);
            server.stop().await;
            client
                .get(format!("http://localhost:{port}"))
                .send()
                .await
                .expect_err("request should fail after server is stopped");
            resp
        }

        #[tokio::test]
        async fn empty() {
            let resp = run(8000, HttpResponse::Empty).await;
            let text = resp.text().await.expect("failed to read response body");
            assert!(text.is_empty());
        }

        #[tokio::test]
        async fn html() {
            let expected = "<head></head>";
            let resp = run(8001, HttpResponse::Html(expected.into())).await;
            let text = resp.text().await.expect("failed to read response body");
            assert_eq!(text, expected);
        }

        #[tokio::test]
        async fn json() {
            let expected = Value::String("val".into());
            let resp = run(8002, HttpResponse::Json(expected.clone())).await;
            let json: Value = resp.json().await.expect("failed to read response body");
            assert_eq!(json, expected);
        }

        #[tokio::test]
        async fn text() {
            let expected = "val";
            let resp = run(8003, HttpResponse::Text(expected.into())).await;
            let text = resp.text().await.expect("failed to read response body");
            assert_eq!(text, expected);
        }
    }
}