1use std::{collections::HashMap, io, net::SocketAddr};
2
3use async_trait::async_trait;
4use axum::{
5 body::Bytes,
6 extract::Query,
7 http::{HeaderMap, Method, StatusCode, Uri},
8 response::{Html, IntoResponse},
9 Json, Router, Server,
10};
11use serde_json::Value;
12use tokio::{
13 spawn,
14 sync::{mpsc, oneshot},
15 task::JoinHandle,
16};
17use tracing::{error, warn};
18
/// An HTTP request as recorded by an [`HttpServer`].
///
/// Every field is an owned copy, so a recorded request can be stored and compared
/// independently of the connection that carried it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpRequest {
    /// Raw request body.
    pub body: Vec<u8>,
    /// Header values keyed by header name.
    pub headers: HashMap<String, String>,
    /// Request method, e.g. `GET`.
    pub method: String,
    /// Request path, without the query string.
    pub path: String,
    /// Query parameters; each key maps to all of its values in order of appearance.
    pub query: HashMap<String, Vec<String>>,
}
32
33#[derive(Clone, Debug, Eq, PartialEq)]
39pub enum HttpResponse {
40 Empty,
41 Html(String),
42 Json(Value),
43 Text(String),
44}
45
46#[async_trait]
54pub trait HttpServer: Send + Sync {
55 async fn next(&mut self) -> Option<HttpRequest>;
59
60 async fn stop(self);
62}
63
64pub struct DefaultHttpServer {
72 req_rx: mpsc::Receiver<HttpRequest>,
73 server: JoinHandle<()>,
74 stop_tx: oneshot::Sender<()>,
75}
76
77impl DefaultHttpServer {
78 pub async fn start(addr: &SocketAddr) -> io::Result<Self> {
82 Self::with_response(addr, HttpResponse::Empty).await
83 }
84
85 pub async fn with_response(addr: &SocketAddr, resp: HttpResponse) -> io::Result<Self> {
89 let (stop_tx, stop_rx) = oneshot::channel();
90 let (req_tx, req_rx) = mpsc::channel(1);
91 let app = Router::new().fallback(
92 move |method: Method,
93 uri: Uri,
94 Query(query): Query<Vec<(String, String)>>,
95 headers: HeaderMap,
96 body: Bytes| async move {
97 let mut req_headers = HashMap::new();
98 for (name, val) in headers {
99 let name = if let Some(name) = &name {
100 name.as_str()
101 } else {
102 warn!("request contains header with no name");
103 continue;
104 };
105 let val = match val.to_str() {
106 Ok(val) => val,
107 Err(err) => {
108 warn!(details = %err, header = name, "failed to decode header value");
109 continue;
110 }
111 };
112 req_headers.insert(name.into(), val.into());
113 }
114 let query = query.into_iter().fold(
115 HashMap::<String, Vec<String>>::new(),
116 |mut query, (key, val)| {
117 query.entry(key).or_default().push(val);
118 query
119 },
120 );
121 let req = HttpRequest {
122 body: body.to_vec(),
123 headers: req_headers,
124 method: method.to_string(),
125 path: uri.path().into(),
126 query,
127 };
128 req_tx.send(req).await.ok();
129 match resp {
130 HttpResponse::Empty => StatusCode::OK.into_response(),
131 HttpResponse::Html(html) => (StatusCode::OK, Html(html)).into_response(),
132 HttpResponse::Json(json) => (StatusCode::OK, Json(json)).into_response(),
133 HttpResponse::Text(text) => (StatusCode::OK, text).into_response(),
134 }
135 },
136 );
137 let server = Server::bind(addr)
138 .serve(app.into_make_service())
139 .with_graceful_shutdown(async {
140 stop_rx.await.ok();
141 });
142 let server = spawn(async {
143 if let Err(err) = server.await {
144 error!(details = %err, "failed to start server");
145 }
146 });
147 Ok(Self {
148 req_rx,
149 server,
150 stop_tx,
151 })
152 }
153}
154
155#[async_trait]
156impl HttpServer for DefaultHttpServer {
157 async fn next(&mut self) -> Option<HttpRequest> {
158 self.req_rx.recv().await
159 }
160
161 async fn stop(self) {
162 self.stop_tx.send(()).ok();
163 if let Err(err) = self.server.await {
164 error!(details = %err, "failed to stop server");
165 }
166 }
167}
168
// Mock implementation of the `HttpServer` trait, compiled only with the `mock` feature.
// `mockall::mock!` generates a `MockHttpServer` type whose `next`/`stop` behaviour is
// configured per test through mockall expectations.
#[cfg(feature = "mock")]
mockall::mock! {
    pub HttpServer {}

    // `#[async_trait]` is required so the generated impl matches the async signatures
    // of the `HttpServer` trait definition.
    #[async_trait]
    impl HttpServer for HttpServer {
        async fn next(&mut self) -> Option<HttpRequest>;
        async fn stop(self);
    }
}
186
#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use reqwest::{Client, Response};

    use super::*;

    mod default_http_server {
        use super::*;

        /// Starts a server on `port` that answers with `resp`, sends it one request and
        /// checks that the recorded request matches what was sent. It then stops the
        /// server, checks that it no longer accepts requests, and returns the response
        /// the client received.
        async fn run(port: u16, resp: HttpResponse) -> Response {
            let expected = HttpRequest {
                body: "abc".to_string().into_bytes(),
                headers: HashMap::from_iter([
                    ("accept".into(), "*/*".into()),
                    ("content-length".into(), "3".into()),
                    ("host".into(), format!("localhost:{port}")),
                ]),
                method: "GET".into(),
                path: "/a/b".into(),
                query: HashMap::from_iter([("foo".into(), vec!["bar1".into(), "bar2".into()])]),
            };
            let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
            let mut server = DefaultHttpServer::with_response(&addr, resp)
                .await
                .expect("failed to start server");
            // No warm-up delay is needed: the listening socket is bound before
            // `with_response` returns, so the kernel queues the connection until the
            // spawned server task accepts it.
            let client = Client::new();
            let query: Vec<(String, String)> = expected
                .query
                .clone()
                .into_iter()
                .flat_map(|(key, val)| val.into_iter().map(move |val| (key.clone(), val)))
                .collect();
            let resp = client
                .get(format!("http://localhost:{port}{}", expected.path))
                .query(&query)
                .body(expected.body.clone())
                .send()
                .await
                .expect("failed to send request");
            let status = resp.status();
            if status != reqwest::StatusCode::OK {
                let body = resp.text().await.expect("failed to read response body");
                panic!("request failed with status {status}: {body}");
            }
            let req = server.next().await.expect("failed to receive request");
            assert_eq!(req, expected);
            server.stop().await;
            client
                .get(format!("http://localhost:{port}"))
                .send()
                .await
                .expect_err("request should fail after server is stopped");
            resp
        }

        #[tokio::test]
        async fn empty() {
            let resp = run(8000, HttpResponse::Empty).await;
            let text = resp.text().await.expect("failed to read response body");
            assert!(text.is_empty());
        }

        #[tokio::test]
        async fn html() {
            let expected = "<head></head>";
            let resp = run(8001, HttpResponse::Html(expected.into())).await;
            let text = resp.text().await.expect("failed to read response body");
            assert_eq!(text, expected);
        }

        #[tokio::test]
        async fn json() {
            let expected = Value::String("val".into());
            let resp = run(8002, HttpResponse::Json(expected.clone())).await;
            let json: Value = resp.json().await.expect("failed to read response body");
            assert_eq!(json, expected);
        }

        #[tokio::test]
        async fn text() {
            let expected = "val";
            let resp = run(8003, HttpResponse::Text(expected.into())).await;
            let text = resp.text().await.expect("failed to read response body");
            assert_eq!(text, expected);
        }
    }
}