1use hyper::{Body, Request, Response};
2use tokio::net::{TcpListener, TcpStream};
3use tokio_noise::{
4 handshakes::{nn_psk2::Responder, NNpsk2},
5 NoiseTcpStream,
6};
7
8use std::{error::Error, future::Future, net::SocketAddr, time::Duration};
9
10use crate::ServerError;
11
12pub async fn serve_http<Psk, P, H, F, E>(
13 tcp_stream: TcpStream,
14 mut responder: Responder<P, Psk>,
15 mut handle_request: H,
16 handler_timeout: Option<Duration>,
17) -> Result<(), ServerError>
18where
19 P: FnMut(&[u8]) -> Option<Psk>,
20 Psk: AsRef<[u8]>,
21 H: FnMut(&[u8], Request<Body>) -> F,
22 F: 'static + Send + Future<Output = Result<Response<Body>, E>>,
23 E: Into<Box<dyn Error + Send + Sync>>,
24{
25 let timeout = handler_timeout.unwrap_or(Duration::from_secs(999999999));
26 tokio::time::timeout(timeout, async move {
27 let handshake = NNpsk2::new(&mut responder);
28 let noise_stream = NoiseTcpStream::handshake_responder(tcp_stream, handshake).await?;
29
30 let peer_identity = responder
31 .initiator_identity()
32 .expect("initiator identity is always set after successful handshake")
33 .to_owned();
34
35 let http_service =
36 hyper::service::service_fn(move |req| handle_request(&peer_identity, req));
37
38 hyper::server::conn::Http::new()
39 .serve_connection(noise_stream, http_service)
40 .await?;
41 Ok(())
42 })
43 .await
44 .map_err(|_| ServerError::HandlerTimeout)?
45}
46
47pub async fn accept_and_serve_http<Psk, P, M1, M2, Svc, F, E>(
48 listener: TcpListener,
49 mut make_responder: M1,
50 mut make_handle_request: M2,
51 timeout: Option<Duration>,
52) -> Result<(), std::io::Error>
53where
54 M1: FnMut(SocketAddr) -> Responder<P, Psk>,
55 P: 'static + Send + FnMut(&[u8]) -> Option<Psk>,
56 Psk: 'static + Send + Sync + AsRef<[u8]>,
57 M2: FnMut(SocketAddr) -> Svc,
58 Svc: 'static + Send + FnMut(&[u8], Request<Body>) -> F,
59 F: 'static + Send + Future<Output = Result<Response<Body>, E>>,
60 E: Into<Box<dyn Error + Send + Sync>>,
61{
62 loop {
63 let (tcp_stream, remote_addr) = match listener.accept().await {
64 Ok(s) => s,
65 Err(e) => return Err(e)?,
66 };
67
68 let responder = make_responder(remote_addr);
69 let handle_request: Svc = make_handle_request(remote_addr);
70
71 tokio::task::spawn(async move {
72 let result = serve_http(tcp_stream, responder, handle_request, timeout).await;
73
74 if let Err(e) = result {
75 log::warn!(
76 "failed to serve HTTP request from {} over noise: {}",
77 remote_addr,
78 e
79 );
80 }
81 });
82 }
83}
84
#[cfg(test)]
mod tests {
    // These tests check that `accept_and_serve_http` accepts several shapes of responder and
    // handler closures. `accept_and_serve_http` only returns on an accept error, so running
    // them would block indefinitely; presumably that is why they are `#[ignore]` and
    // building the test binary is the real check.

    use super::*;
    use hyper::{Body, Request, Response};
    use std::{
        collections::HashMap,
        convert::Infallible,
        sync::{Arc, Mutex},
    };

    /// Pre-shared keys for three named peers, shared by every test below.
    fn peer_table() -> HashMap<Vec<u8>, [u8; 32]> {
        HashMap::from([
            (Vec::from(b"alice"), [0u8; 32]),
            (Vec::from(b"bob"), [1u8; 32]),
            (Vec::from(b"charlie"), [2u8; 32]),
        ])
    }

    /// Binds a listener to an OS-chosen port on 127.0.0.1.
    async fn loopback_listener() -> tokio::net::TcpListener {
        tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap()
    }

    #[ignore]
    #[tokio::test]
    async fn compiles_with_closure() {
        let shared_peers = Arc::new(peer_table());

        let make_responder = move |_| {
            let table = Arc::clone(&shared_peers);
            Responder::new(move |id| table.get(id).cloned())
        };

        let make_handle_request = |_| {
            |peer_id: &[u8], _req: Request<Body>| async move {
                let _ = peer_id;
                let resp = Response::new(Body::empty());
                Ok::<_, Infallible>(resp)
            }
        };

        let listener = loopback_listener().await;
        accept_and_serve_http(listener, make_responder, make_handle_request, None)
            .await
            .unwrap();
    }

    #[ignore]
    #[tokio::test]
    async fn compiles_with_fn() {
        let shared_peers = Arc::new(peer_table());

        let make_responder = move |_| {
            let table = Arc::clone(&shared_peers);
            Responder::new(move |id| table.get(id).cloned())
        };

        async fn handle_request(
            _peer_name: &[u8],
            _req: Request<Body>,
        ) -> Result<Response<Body>, Infallible> {
            let resp = Response::new(Body::empty());
            Ok::<_, Infallible>(resp)
        }

        // The borrowed peer id is copied into an owned Vec so the returned future does not
        // borrow from the handler's argument.
        let make_handle_request = |_| {
            |peer_id: &[u8], req| {
                let peer_id = peer_id.to_vec();
                async move { handle_request(&peer_id, req).await }
            }
        };

        let listener = loopback_listener().await;
        accept_and_serve_http(listener, make_responder, make_handle_request, None)
            .await
            .unwrap();
    }

    #[ignore]
    #[tokio::test]
    async fn compiles_with_mutable_peers() {
        let shared_peers = Arc::new(Mutex::new(peer_table()));

        let make_responder = move |_| {
            let table = Arc::clone(&shared_peers);
            Responder::new(move |id| table.lock().unwrap().get(id).cloned())
        };

        let make_handle_request = |_| {
            |peer_id: &[u8], _req: Request<Body>| async move {
                let _ = peer_id;
                let resp = Response::new(Body::empty());
                Ok::<_, Infallible>(resp)
            }
        };

        let listener = loopback_listener().await;
        accept_and_serve_http(listener, make_responder, make_handle_request, None)
            .await
            .unwrap();
    }
}