mini_http_test/
server.rs

1use std::{
2    net::SocketAddr,
3    sync::{Arc, Mutex},
4    time::Duration,
5};
6
7use async_trait::async_trait;
8use http_body_util::BodyExt;
9use hyper::{
10    body::{self, Body, Bytes},
11    server::conn::http1,
12    service::service_fn,
13    Request, Uri,
14};
15use tokio::{net::TcpListener, select, sync::watch, time::Instant};
16
17use crate::Error;
18
19use crate::{run_handler, Handler};
20
21/// Listens on a random port, running the given function to handle each request.
22///
23/// See the crate documentation for an example.
24#[derive(Debug, Clone)]
25pub struct Server {
26    close_tx: Arc<watch::Sender<u8>>,
27    addr: SocketAddr,
28    req_count: Arc<Mutex<u64>>,
29    concurrent_req_count: Arc<Mutex<u64>>,
30}
31
32impl Server {
33    /// Creates a new HTTP server on a random port running the given handler.
34    /// The handler will be called on every request, and the total request count
35    /// can be retrieved with [server.req_count()](Server::req_count).
36    ///
37    /// The server can be safely cloned and used from multiple threads. When the
38    /// final reference to the server is dropped, the server will be shut down
39    /// and all pending requests will be aborted. Aborting the server will
40    /// happen in the background and will not block.
41    ///
42    /// The remote socket address is added as a
43    /// [SocketAddr](std::net::SocketAddr)
44    /// [extension](hyper::Request::extensions) to the request object.
45    pub async fn new<H: Handler + Clone + Send + Sync + 'static>(
46        handler: H,
47    ) -> Result<Self, Error> {
48        let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
49        let tcp_listener = TcpListener::bind(addr)
50            .await
51            .map_err(Error::BindTCPListener)?;
52        let addr = tcp_listener
53            .local_addr()
54            .map_err(Error::GetTCPListenerAddress)?;
55
56        let (close_tx, close_rx) = watch::channel::<u8>(0);
57        let req_count = Arc::new(Mutex::new(0));
58        let concurrent_req_count = Arc::new(Mutex::new(0));
59
60        {
61            let handler = handler.clone();
62            let req_count = req_count.clone();
63            let concurrent_req_count = concurrent_req_count.clone();
64
65            tokio::spawn(async move {
66                let mut close_rx = close_rx.clone();
67
68                loop {
69                    let (tcp_stream, remote_addr) = select! {
70                        _ = close_rx.changed() => {
71                            return;
72                        }
73                        res = tcp_listener.accept() => {
74                            match res {
75                                Ok(res) => res,
76                                Err(err) => {
77                                    eprintln!("Error while accepting TCP connection: {}", err);
78                                    return;
79                                }
80                            }
81                        }
82                    };
83
84                    let handler = handler.clone();
85                    let mut close_rx = close_rx.clone();
86                    let req_count = req_count.clone();
87                    let concurrent_req_count = concurrent_req_count.clone();
88                    tokio::spawn(async move {
89                        let handler = &handler;
90                        let req_count = &req_count;
91                        let concurrent_req_count = &concurrent_req_count;
92
93                        let service = service_fn(|mut req: Request<body::Incoming>| async move {
94                            *concurrent_req_count.lock().expect("lock poisoned") += 1;
95                            req.extensions_mut().insert(remote_addr);
96                            let res = run_handler(handler.clone(), req).await;
97                            *concurrent_req_count.lock().expect("lock poisoned") -= 1;
98                            *req_count.lock().expect("lock poisoned") += 1;
99                            res
100                        });
101
102                        let res = select! {
103                            _ = close_rx.changed() => {
104                                return;
105                            }
106                            res = http1::Builder::new()
107                                .keep_alive(true)
108                                .serve_connection(tcp_stream, service) => res,
109                        };
110
111                        if let Err(http_err) = res {
112                            eprintln!("Error while serving HTTP connection: {}", http_err);
113                        }
114                    });
115                }
116            });
117        };
118
119        Ok(Self {
120            close_tx: Arc::new(close_tx),
121            addr,
122            req_count,
123            concurrent_req_count,
124        })
125    }
126
127    /// Returns the socket address the server is listening on.
128    pub fn addr(&self) -> SocketAddr {
129        self.addr
130    }
131
132    /// Returns a valid request URL for the given path and query string.
133    pub fn url(&self, path_and_query: &str) -> Uri {
134        Uri::builder()
135            .scheme("http")
136            .authority(self.addr.to_string().as_str())
137            .path_and_query(path_and_query)
138            .build()
139            .expect("should be a valid URL")
140    }
141
142    /// Returns the number of requests handled by the server. This value is
143    /// incremented after the request handler has finished, but before the
144    /// response has been sent.
145    ///
146    /// At the end of tests, this should be asserted to be equal to the amount
147    /// of requests sent.
148    ///
149    /// Any panics in the request handler may result in the counter becoming out
150    /// of sync.
151    pub fn req_count(&self) -> u64 {
152        *self.req_count.lock().expect("lock poisoned")
153    }
154
155    /// Await req_count reaching a certain number. This polls every 10ms and
156    /// times out after the given duration.
157    pub async fn await_req_count(&self, target_count: u64, timeout: Duration) -> Result<(), Error> {
158        let start = Instant::now();
159        loop {
160            let current_count = self.req_count();
161            if current_count == target_count {
162                return Ok(());
163            }
164
165            if start.elapsed() > timeout {
166                return Err(Error::AwaitReqCountTimeout {
167                    current_count,
168                    target_count,
169                    timeout,
170                });
171            }
172
173            tokio::time::sleep(Duration::from_millis(10)).await;
174        }
175    }
176
177    /// Returns the number of concurrent requests currently being handled by the
178    /// server. A concurrent request is measured by incrementing the counter
179    /// before the request handler is called, and decrementing it after the
180    /// request handler has finished. The response may still be in the process
181    /// of being sent when the counter is decremented.
182    ///
183    /// Any panics in the request handler may result in the counter becoming out
184    /// of sync.
185    pub fn concurrent_req_count(&self) -> u64 {
186        *self.concurrent_req_count.lock().expect("lock poisoned")
187    }
188
189    /// Await concurrent_req_count reaching a certain number. This polls every
190    /// 10ms and times out after the given duration.
191    pub async fn await_concurrent_req_count(
192        &self,
193        target_count: u64,
194        timeout: Duration,
195    ) -> Result<(), Error> {
196        let start = Instant::now();
197        loop {
198            let current_count = self.concurrent_req_count();
199            if current_count == target_count {
200                return Ok(());
201            }
202
203            if start.elapsed() > timeout {
204                return Err(Error::AwaitConcurrentReqCountTimeout {
205                    current_count,
206                    target_count,
207                    timeout,
208                });
209            }
210
211            tokio::time::sleep(Duration::from_millis(10)).await;
212        }
213    }
214
215    /// close kills the server and aborts all pending requests. This does not
216    /// block for all requests to finish.
217    pub fn close(&self) {
218        self.close_tx.send(1).expect("failed to close server");
219    }
220}
221
222impl Drop for Server {
223    fn drop(&mut self) {
224        if Arc::strong_count(&self.close_tx) == 1 {
225            self.close();
226        }
227    }
228}
229
230/// A handy extension to [hyper::Request](hyper::Request) that allows for easily
231/// reading the request body as a single `Bytes` object.
232#[async_trait]
233pub trait GetRequestBody {
234    async fn body_bytes(self) -> Result<Bytes, hyper::Error>;
235}
236
237#[async_trait]
238impl<B> GetRequestBody for Request<B>
239where
240    B: Body<Data = Bytes> + Send + Sync + 'static,
241    <B as Body>::Error: Into<hyper::Error>,
242{
243    async fn body_bytes(self) -> Result<Bytes, hyper::Error> {
244        self.into_body()
245            .collect()
246            .await
247            .map(|full| full.to_bytes())
248            .map_err(|err| err.into())
249    }
250}
251
#[cfg(test)]
mod test {
    use http_body_util::Full;
    use hyper::{body::Bytes, Response};

    use super::*;
    use crate::handle_ok;

    /// Number of requests each test sends.
    const ITERATIONS: u64 = 10;

    #[tokio::test]
    async fn server_ok() {
        // Echo handler: responds with the request body unchanged.
        async fn handler(
            req: Request<body::Incoming>,
        ) -> Result<Response<Full<Bytes>>, hyper::Error> {
            let body = req.body_bytes().await?;

            Ok(Response::new(Full::new(body)))
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        for i in 0..ITERATIONS {
            let res = client
                .post(server.url("/").to_string())
                .body(format!("hello world {}", i))
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            assert_eq!(
                res.text().await.expect("read response"),
                format!("hello world {}", i)
            );

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    #[tokio::test]
    async fn server_move_closure_copy() {
        let val = 1234;
        let server = Server::new(move |_: Request<body::Incoming>| async move {
            handle_ok(Response::new(val.to_string().into()))
        })
        .await
        .expect("create server");

        let client = reqwest::Client::new();

        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            assert_eq!(res.text().await.expect("read response"), val.to_string());

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    #[tokio::test]
    async fn server_move_closure_arc() {
        let val = Arc::new(Mutex::new(1234));
        let server = {
            let val = val.clone();
            Server::new(move |_: Request<body::Incoming>| async move {
                let mut val = val.lock().expect("lock poisoned");
                *val += 1;
                handle_ok(Response::new(val.to_string().into()))
            })
            .await
            .expect("create server")
        };

        let client = reqwest::Client::new();

        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            assert_eq!(
                res.text().await.expect("read response"),
                val.lock().expect("lock poisoned").to_string()
            );

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    #[tokio::test]
    async fn server_failure() {
        async fn handler(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
            Err("Internal Server Error".to_string())
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 500);
            assert_eq!(
                res.text().await.expect("read response"),
                "Internal Server Error"
            );

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 16)]
    async fn server_await_req_count() {
        async fn handler(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
            Ok(Response::new("hello world".into()))
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        // Spawn tasks that will send requests to the server concurrently.
        let url = server.url("/").to_string();
        let futures: Vec<tokio::task::JoinHandle<()>> = (0..ITERATIONS)
            .map(|_| {
                let client = client.clone();
                let url = url.clone();

                tokio::spawn(async move {
                    let res = client.get(url).send().await.expect("send request");
                    assert_eq!(res.status(), 200);
                })
            })
            .collect();

        server
            .await_req_count(ITERATIONS, Duration::from_secs(1))
            .await
            .expect("requests finished");
        assert_eq!(server.req_count(), ITERATIONS);

        // Ensure all requests have finished.
        for fut in futures {
            fut.await.unwrap();
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 16)]
    async fn server_long_requests_cancellation() {
        async fn handler(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
            // Sleep for 10 seconds to simulate a long request.
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(Response::new("hello world".into()))
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        // Spawn tasks that will send requests to the server concurrently.
        let url = server.url("/").to_string();
        let futures: Vec<tokio::task::JoinHandle<Result<(), String>>> = (0..ITERATIONS)
            .map(|_| {
                let client = client.clone();
                let url = url.clone();

                tokio::spawn(async move {
                    let res = client.get(url).send().await;
                    match res {
                        Ok(_) => Err("expected request to be canceled".to_string()),
                        Err(_) => Ok(()),
                    }
                })
            })
            .collect();

        server
            .await_concurrent_req_count(ITERATIONS, Duration::from_secs(1))
            .await
            .expect("requests start");
        assert_eq!(server.concurrent_req_count(), ITERATIONS);

        // Drop the server and the requests should be canceled immediately.
        let now = Instant::now();
        drop(server);

        // Wait for the requests to be canceled.
        for fut in futures {
            fut.await.unwrap().expect("request canceled");
        }
        assert!(now.elapsed() < Duration::from_secs(1));
    }

    #[tokio::test]
    async fn server_keep_alive() {
        let server = {
            // Remembers the peer address of the first request so later requests
            // can be checked against it.
            let last_socket_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
            Server::new(move |req: Request<body::Incoming>| async move {
                let socket_addr = req.extensions().get::<SocketAddr>().unwrap();
                let mut last_socket_addr = last_socket_addr.lock().expect("lock poisoned");
                match *last_socket_addr {
                    Some(last_socket_addr) => {
                        assert_eq!(&last_socket_addr, socket_addr);
                    }
                    None => {
                        *last_socket_addr = Some(*socket_addr);
                    }
                }

                handle_ok(Response::new("hello world".into()))
            })
            .await
            .expect("create server")
        };

        let client = reqwest::Client::new();

        // Send requests sequentially and ensure they all reuse the same socket.
        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            assert_eq!(res.text().await.expect("read response"), "hello world");
            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }
}