mockable/
http.rs

1use std::{collections::HashMap, io, net::SocketAddr};
2
3use async_trait::async_trait;
4use axum::{
5    body::Bytes,
6    extract::Query,
7    http::{HeaderMap, Method, StatusCode, Uri},
8    response::{Html, IntoResponse},
9    Json, Router, Server,
10};
11use serde_json::Value;
12use tokio::{
13    spawn,
14    sync::{mpsc, oneshot},
15    task::JoinHandle,
16};
17use tracing::{error, warn};
18
// HttpRequest

/// HTTP request received by an [`HttpServer`](trait.HttpServer.html).
///
/// **This is supported on `feature=http` only.**
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpRequest {
    /// Raw request body.
    pub body: Vec<u8>,
    /// Request headers, keyed by header name.
    pub headers: HashMap<String, String>,
    /// Request method (e.g. `GET`).
    pub method: String,
    /// Path component of the request URI.
    pub path: String,
    /// Query parameters: each key maps to every value it was given, in the order received.
    pub query: HashMap<String, Vec<String>>,
}
32
33// HttpResponse
34
35/// HTTP response.
36///
37/// **This is supported on `feature=http` only.**
38#[derive(Clone, Debug, Eq, PartialEq)]
39pub enum HttpResponse {
40    Empty,
41    Html(String),
42    Json(Value),
43    Text(String),
44}
45
46// HttpServer
47
48/// Simple HTTP server that listen all requests.
49///
50/// **This is supported on `feature=http` only.**
51///
52/// [Example](https://github.com/leroyguillaume/mockable/tree/main/examples/http.rs).
53#[async_trait]
54pub trait HttpServer: Send + Sync {
55    /// Returns the next request received by the server.
56    ///
57    /// `None` is returned if the server is stopped.
58    async fn next(&mut self) -> Option<HttpRequest>;
59
60    /// Stops the server.
61    async fn stop(self);
62}
63
64// DefaultHttpServer
65
66/// Default implementation of [`HttpServer`](trait.HttpServer.html).
67///
68/// **This is supported on `feature=http` only.**
69///
70/// [Example](https://github.com/leroyguillaume/mockable/tree/main/examples/http.rs).
71pub struct DefaultHttpServer {
72    req_rx: mpsc::Receiver<HttpRequest>,
73    server: JoinHandle<()>,
74    stop_tx: oneshot::Sender<()>,
75}
76
77impl DefaultHttpServer {
78    /// Starts a new server listening on the given address.
79    ///
80    /// The server will respond status code 200 with an empty response to all requests.
81    pub async fn start(addr: &SocketAddr) -> io::Result<Self> {
82        Self::with_response(addr, HttpResponse::Empty).await
83    }
84
85    /// Starts a new server listening on the given address.
86    ///
87    /// The server will respond status code 200 with the given one to all requests.
88    pub async fn with_response(addr: &SocketAddr, resp: HttpResponse) -> io::Result<Self> {
89        let (stop_tx, stop_rx) = oneshot::channel();
90        let (req_tx, req_rx) = mpsc::channel(1);
91        let app = Router::new().fallback(
92            move |method: Method,
93                  uri: Uri,
94                  Query(query): Query<Vec<(String, String)>>,
95                  headers: HeaderMap,
96                  body: Bytes| async move {
97                let mut req_headers = HashMap::new();
98                for (name, val) in headers {
99                    let name = if let Some(name) = &name {
100                        name.as_str()
101                    } else {
102                        warn!("request contains header with no name");
103                        continue;
104                    };
105                    let val = match val.to_str() {
106                        Ok(val) => val,
107                        Err(err) => {
108                            warn!(details = %err, header = name, "failed to decode header value");
109                            continue;
110                        }
111                    };
112                    req_headers.insert(name.into(), val.into());
113                }
114                let query = query.into_iter().fold(
115                    HashMap::<String, Vec<String>>::new(),
116                    |mut query, (key, val)| {
117                        query.entry(key).or_default().push(val);
118                        query
119                    },
120                );
121                let req = HttpRequest {
122                    body: body.to_vec(),
123                    headers: req_headers,
124                    method: method.to_string(),
125                    path: uri.path().into(),
126                    query,
127                };
128                req_tx.send(req).await.ok();
129                match resp {
130                    HttpResponse::Empty => StatusCode::OK.into_response(),
131                    HttpResponse::Html(html) => (StatusCode::OK, Html(html)).into_response(),
132                    HttpResponse::Json(json) => (StatusCode::OK, Json(json)).into_response(),
133                    HttpResponse::Text(text) => (StatusCode::OK, text).into_response(),
134                }
135            },
136        );
137        let server = Server::bind(addr)
138            .serve(app.into_make_service())
139            .with_graceful_shutdown(async {
140                stop_rx.await.ok();
141            });
142        let server = spawn(async {
143            if let Err(err) = server.await {
144                error!(details = %err, "failed to start server");
145            }
146        });
147        Ok(Self {
148            req_rx,
149            server,
150            stop_tx,
151        })
152    }
153}
154
155#[async_trait]
156impl HttpServer for DefaultHttpServer {
157    async fn next(&mut self) -> Option<HttpRequest> {
158        self.req_rx.recv().await
159    }
160
161    async fn stop(self) {
162        self.stop_tx.send(()).ok();
163        if let Err(err) = self.server.await {
164            error!(details = %err, "failed to stop server");
165        }
166    }
167}
168
// MockHttpServer

// Generates `MockHttpServer`, a `mockall` mock of `HttpServer`, compiled only when the `mock`
// feature is enabled.
#[cfg(feature = "mock")]
mockall::mock! {
    /// `mockall` implementation of [`HttpServer`](trait.HttpServer.html).
    ///
    /// **This is supported on `feature=http,mock` only.**
    ///
    /// [Example](https://github.com/leroyguillaume/mockable/tree/main/examples/http.rs).
    pub HttpServer {}

    // `#[async_trait]` is repeated here so the mocked methods match the trait's
    // `async_trait`-generated signatures.
    #[async_trait]
    impl HttpServer for HttpServer {
        async fn next(&mut self) -> Option<HttpRequest>;
        async fn stop(self);
    }
}
186
// Tests

#[cfg(test)]
mod test {
    use std::net::{Ipv4Addr, SocketAddrV4};

    use reqwest::{Client, Response};

    use super::*;

    // Mods

    mod default_http_server {
        use super::*;

        // run

        /// Starts a server on `port` that answers with `resp` and sends it one request. Checks
        /// that the server reports exactly that request, then stops the server and checks that
        /// it no longer answers.
        ///
        /// Returns the client-side response so each test can check the body it received.
        ///
        /// NOTE(review): the tests use fixed ports (8000-8003) and fail if another process
        /// already listens on one of them. `DefaultHttpServer` does not expose its bound address,
        /// so port 0 cannot be used yet.
        async fn run(port: u16, resp: HttpResponse) -> Response {
            let expected = HttpRequest {
                body: b"abc".to_vec(),
                headers: HashMap::from_iter([
                    ("accept".into(), "*/*".into()),
                    ("content-length".into(), "3".into()),
                    ("host".into(), format!("localhost:{port}")),
                ]),
                method: "GET".into(),
                path: "/a/b".into(),
                query: HashMap::from_iter([("foo".into(), vec!["bar1".into(), "bar2".into()])]),
            };
            let addr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port));
            // No start-up delay is needed: the listener is bound before `with_response` returns,
            // so the connection below waits in the listen backlog until the server accepts it.
            let mut server = DefaultHttpServer::with_response(&addr, resp)
                .await
                .expect("failed to start server");
            let client = Client::new();
            // Flatten the multi-valued query back into `key=value` pairs.
            let query: Vec<(&str, &str)> = expected
                .query
                .iter()
                .flat_map(|(key, vals)| vals.iter().map(move |val| (key.as_str(), val.as_str())))
                .collect();
            let resp = client
                .get(format!("http://localhost:{port}{}", expected.path))
                .query(&query)
                .body(expected.body.clone())
                .send()
                .await
                .expect("failed to send request");
            let status = resp.status();
            if status != reqwest::StatusCode::OK {
                let body = resp.text().await.expect("failed to read response body");
                panic!("request failed with status {status}: {body}");
            }
            let req = server.next().await.expect("failed to receive request");
            assert_eq!(req, expected);
            server.stop().await;
            client
                .get(format!("http://localhost:{port}"))
                .send()
                .await
                .expect_err("request should fail after server is stopped");
            resp
        }

        // Tests

        #[tokio::test]
        async fn empty() {
            let resp = run(8000, HttpResponse::Empty).await;
            let text = resp.text().await.expect("failed to read response body");
            assert!(text.is_empty());
        }

        #[tokio::test]
        async fn html() {
            let expected = "<head></head>";
            let resp = run(8001, HttpResponse::Html(expected.into())).await;
            let text = resp.text().await.expect("failed to read response body");
            assert_eq!(text, expected);
        }

        #[tokio::test]
        async fn json() {
            let expected = Value::String("val".into());
            let resp = run(8002, HttpResponse::Json(expected.clone())).await;
            let json: Value = resp.json().await.expect("failed to read response body");
            assert_eq!(json, expected);
        }

        #[tokio::test]
        async fn text() {
            let expected = "val";
            let resp = run(8003, HttpResponse::Text(expected.into())).await;
            let text = resp.text().await.expect("failed to read response body");
            assert_eq!(text, expected);
        }
    }
}