1use std::{
2 net::SocketAddr,
3 sync::{Arc, Mutex},
4 time::Duration,
5};
6
7use async_trait::async_trait;
8use http_body_util::BodyExt;
9use hyper::{
10 body::{self, Body, Bytes},
11 server::conn::http1,
12 service::service_fn,
13 Request, Uri,
14};
15use tokio::{net::TcpListener, select, sync::watch, time::Instant};
16
17use crate::Error;
18
19use crate::{run_handler, Handler};
20
/// A local HTTP/1.1 test server that forwards every request to a user-supplied handler.
///
/// Cloning is cheap: every clone shares the same background tasks and counters through
/// `Arc`s. The server stops when `close` is called, or when the last clone is dropped.
#[derive(Debug, Clone)]
pub struct Server {
    // Shutdown signal watched by the accept loop and by every connection task.
    close_tx: Arc<watch::Sender<u8>>,
    // Address the listener is actually bound to (127.0.0.1 with an OS-assigned port).
    addr: SocketAddr,
    // Number of requests whose handler has returned (successfully or not).
    req_count: Arc<Mutex<u64>>,
    // Requests currently inside the handler (incremented before `run_handler`,
    // decremented after it returns).
    concurrent_req_count: Arc<Mutex<u64>>,
}
31
32impl Server {
33 pub async fn new<H: Handler + Clone + Send + Sync + 'static>(
46 handler: H,
47 ) -> Result<Self, Error> {
48 let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
49 let tcp_listener = TcpListener::bind(addr)
50 .await
51 .map_err(Error::BindTCPListener)?;
52 let addr = tcp_listener
53 .local_addr()
54 .map_err(Error::GetTCPListenerAddress)?;
55
56 let (close_tx, close_rx) = watch::channel::<u8>(0);
57 let req_count = Arc::new(Mutex::new(0));
58 let concurrent_req_count = Arc::new(Mutex::new(0));
59
60 {
61 let handler = handler.clone();
62 let req_count = req_count.clone();
63 let concurrent_req_count = concurrent_req_count.clone();
64
65 tokio::spawn(async move {
66 let mut close_rx = close_rx.clone();
67
68 loop {
69 let (tcp_stream, remote_addr) = select! {
70 _ = close_rx.changed() => {
71 return;
72 }
73 res = tcp_listener.accept() => {
74 match res {
75 Ok(res) => res,
76 Err(err) => {
77 eprintln!("Error while accepting TCP connection: {}", err);
78 return;
79 }
80 }
81 }
82 };
83
84 let handler = handler.clone();
85 let mut close_rx = close_rx.clone();
86 let req_count = req_count.clone();
87 let concurrent_req_count = concurrent_req_count.clone();
88 tokio::spawn(async move {
89 let handler = &handler;
90 let req_count = &req_count;
91 let concurrent_req_count = &concurrent_req_count;
92
93 let service = service_fn(|mut req: Request<body::Incoming>| async move {
94 *concurrent_req_count.lock().expect("lock poisoned") += 1;
95 req.extensions_mut().insert(remote_addr);
96 let res = run_handler(handler.clone(), req).await;
97 *concurrent_req_count.lock().expect("lock poisoned") -= 1;
98 *req_count.lock().expect("lock poisoned") += 1;
99 res
100 });
101
102 let res = select! {
103 _ = close_rx.changed() => {
104 return;
105 }
106 res = http1::Builder::new()
107 .keep_alive(true)
108 .serve_connection(tcp_stream, service) => res,
109 };
110
111 if let Err(http_err) = res {
112 eprintln!("Error while serving HTTP connection: {}", http_err);
113 }
114 });
115 }
116 });
117 };
118
119 Ok(Self {
120 close_tx: Arc::new(close_tx),
121 addr,
122 req_count,
123 concurrent_req_count,
124 })
125 }
126
127 pub fn addr(&self) -> SocketAddr {
129 self.addr
130 }
131
132 pub fn url(&self, path_and_query: &str) -> Uri {
134 Uri::builder()
135 .scheme("http")
136 .authority(self.addr.to_string().as_str())
137 .path_and_query(path_and_query)
138 .build()
139 .expect("should be a valid URL")
140 }
141
142 pub fn req_count(&self) -> u64 {
152 *self.req_count.lock().expect("lock poisoned")
153 }
154
155 pub async fn await_req_count(&self, target_count: u64, timeout: Duration) -> Result<(), Error> {
158 let start = Instant::now();
159 loop {
160 let current_count = self.req_count();
161 if current_count == target_count {
162 return Ok(());
163 }
164
165 if start.elapsed() > timeout {
166 return Err(Error::AwaitReqCountTimeout {
167 current_count,
168 target_count,
169 timeout,
170 });
171 }
172
173 tokio::time::sleep(Duration::from_millis(10)).await;
174 }
175 }
176
177 pub fn concurrent_req_count(&self) -> u64 {
186 *self.concurrent_req_count.lock().expect("lock poisoned")
187 }
188
189 pub async fn await_concurrent_req_count(
192 &self,
193 target_count: u64,
194 timeout: Duration,
195 ) -> Result<(), Error> {
196 let start = Instant::now();
197 loop {
198 let current_count = self.concurrent_req_count();
199 if current_count == target_count {
200 return Ok(());
201 }
202
203 if start.elapsed() > timeout {
204 return Err(Error::AwaitConcurrentReqCountTimeout {
205 current_count,
206 target_count,
207 timeout,
208 });
209 }
210
211 tokio::time::sleep(Duration::from_millis(10)).await;
212 }
213 }
214
215 pub fn close(&self) {
218 self.close_tx.send(1).expect("failed to close server");
219 }
220}
221
222impl Drop for Server {
223 fn drop(&mut self) {
224 if Arc::strong_count(&self.close_tx) == 1 {
225 self.close();
226 }
227 }
228}
229
/// Extension for reading an entire request body into memory.
#[async_trait]
pub trait GetRequestBody {
    /// Consumes the request and collects its whole body into one [`Bytes`] buffer.
    ///
    /// # Errors
    ///
    /// Returns the body's read error, converted into [`hyper::Error`].
    async fn body_bytes(self) -> Result<Bytes, hyper::Error>;
}
236
237#[async_trait]
238impl<B> GetRequestBody for Request<B>
239where
240 B: Body<Data = Bytes> + Send + Sync + 'static,
241 <B as Body>::Error: Into<hyper::Error>,
242{
243 async fn body_bytes(self) -> Result<Bytes, hyper::Error> {
244 self.into_body()
245 .collect()
246 .await
247 .map(|full| full.to_bytes())
248 .map_err(|err| err.into())
249 }
250}
251
/// End-to-end tests that drive `Server` over real TCP using a `reqwest` client.
#[cfg(test)]
mod test {
    use http_body_util::Full;
    use hyper::{body::Bytes, Response};

    use super::*;
    use crate::handle_ok;

    /// Echo handler: each POST body must come back unchanged, and the request
    /// counter must already include the request once its response has been read.
    #[tokio::test]
    async fn server_ok() {
        async fn handler(
            req: Request<body::Incoming>,
        ) -> Result<Response<Full<Bytes>>, hyper::Error> {
            let body = req.body_bytes().await?;

            Ok(Response::new(Full::new(body)))
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        static ITERATIONS: u64 = 10;
        for i in 0..ITERATIONS {
            let res = client
                .post(server.url("/").to_string())
                .body(format!("hello world {}", i))
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            assert_eq!(
                res.text().await.expect("read response"),
                format!("hello world {}", i)
            );

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    /// A `move` closure capturing a `Copy` value can be used as the handler.
    #[tokio::test]
    async fn server_move_closure_copy() {
        let val = 1234;
        let server = Server::new(move |_: Request<body::Incoming>| async move {
            handle_ok(Response::new(val.to_string().into()))
        })
        .await
        .expect("create server");

        let client = reqwest::Client::new();

        static ITERATIONS: u64 = 10;
        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            assert_eq!(res.text().await.expect("read response"), val.to_string());

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    /// A handler sharing `Arc<Mutex<_>>` state with the test: every request
    /// increments the shared value and returns it, so the test can compare the
    /// response against its own view of the state.
    #[tokio::test]
    async fn server_move_closure_arc() {
        let val = Arc::new(Mutex::new(1234));
        let server = {
            let val = val.clone();
            Server::new(move |_: Request<body::Incoming>| async move {
                let mut val = val.lock().expect("lock poisoned");
                *val += 1;
                handle_ok(Response::new(val.to_string().into()))
            })
            .await
            .expect("create server")
        };

        let client = reqwest::Client::new();

        static ITERATIONS: u64 = 10;
        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            assert_eq!(
                res.text().await.expect("read response"),
                val.lock().expect("lock poisoned").to_string()
            );

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    /// A handler returning `Err(String)` must produce a 500 whose body is the
    /// error text, and failed requests still count towards `req_count`.
    #[tokio::test]
    async fn server_failure() {
        async fn handler(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
            Err("Internal Server Error".to_string())
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        static ITERATIONS: u64 = 10;
        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 500);
            assert_eq!(
                res.text().await.expect("read response"),
                "Internal Server Error"
            );

            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }

    /// Fires all requests concurrently and uses `await_req_count` to wait for them.
    #[tokio::test(flavor = "multi_thread", worker_threads = 16)]
    async fn server_await_req_count() {
        async fn handler(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
            Ok(Response::new("hello world".into()))
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        static ITERATIONS: u64 = 10;
        let url = server.url("/").to_string();
        // Requests are spawned, not awaited, so they run while we wait on the counter.
        let futures: Vec<tokio::task::JoinHandle<()>> = (0..ITERATIONS)
            .map(|_| {
                let client = client.clone();
                let url = url.clone();

                tokio::spawn(async move {
                    let res = client.get(url).send().await.expect("send request");
                    assert_eq!(res.status(), 200);
                })
            })
            .collect();

        server
            .await_req_count(ITERATIONS, Duration::from_secs(1))
            .await
            .expect("requests finished");
        assert_eq!(server.req_count(), ITERATIONS);

        // Join the client tasks so a failed assertion inside them fails this test.
        for fut in futures {
            fut.await.unwrap();
        }
    }

    /// Dropping the server must abort long-running in-flight requests promptly.
    #[tokio::test(flavor = "multi_thread", worker_threads = 16)]
    async fn server_long_requests_cancellation() {
        // Sleeps far past the 1s budget below, so any successful response means
        // the request was not cancelled.
        async fn handler(_: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, String> {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok(Response::new("hello world".into()))
        }

        let server = Server::new(handler).await.expect("create server");

        let client = reqwest::Client::new();

        static ITERATIONS: u64 = 10;
        let url = server.url("/").to_string();
        let futures: Vec<tokio::task::JoinHandle<Result<(), String>>> = (0..ITERATIONS)
            .map(|_| {
                let client = client.clone();
                let url = url.clone();

                tokio::spawn(async move {
                    let res = client.get(url).send().await;
                    match res {
                        Ok(_) => Err("expected request to be canceled".to_string()),
                        Err(_) => Ok(()),
                    }
                })
            })
            .collect();

        // Wait until every request is inside the handler before shutting down.
        server
            .await_concurrent_req_count(ITERATIONS, Duration::from_secs(1))
            .await
            .expect("requests start");
        assert_eq!(server.concurrent_req_count(), ITERATIONS);

        let now = Instant::now();
        // Dropping the only handle runs `Drop for Server`, which sends the close signal.
        drop(server);

        for fut in futures {
            fut.await.unwrap().expect("request canceled");
        }
        assert!(now.elapsed() < Duration::from_secs(1));
    }

    /// Sequential requests from one client must arrive over the same connection:
    /// the handler records the first peer address and asserts every later request
    /// comes from it.
    #[tokio::test]
    async fn server_keep_alive() {
        let server = {
            let last_socket_addr: Arc<Mutex<Option<SocketAddr>>> = Arc::new(Mutex::new(None));
            Server::new(move |req: Request<body::Incoming>| async move {
                let socket_addr = req.extensions().get::<SocketAddr>().unwrap();
                let mut last_socket_addr = last_socket_addr.lock().expect("lock poisoned");
                match *last_socket_addr {
                    Some(last_socket_addr) => {
                        assert_eq!(&last_socket_addr, socket_addr);
                    }
                    None => {
                        *last_socket_addr = Some(*socket_addr);
                    }
                }

                handle_ok(Response::new("hello world".into()))
            })
            .await
            .expect("create server")
        };

        let client = reqwest::Client::new();

        static ITERATIONS: u64 = 10;
        for i in 0..ITERATIONS {
            let res = client
                .get(server.url("/").to_string())
                .send()
                .await
                .expect("send request");

            assert_eq!(res.status(), 200);
            // NOTE(review): reading the body to completion before the next request
            // presumably lets reqwest return the connection to its pool for reuse.
            assert_eq!(res.text().await.expect("read response"), "hello world");
            assert_eq!(server.req_count(), i + 1);
        }

        assert_eq!(server.req_count(), ITERATIONS);
    }
}