// hyper_ws_listener/lib.rs

// Note: the upgrade handling below is modeled on the upgrade example linked from the
// `hyper::upgrade` module documentation.
3
4use base64::encode;
5use std::future::Future;
6
7use hyper::{
8    header::{self, HeaderValue},
9    http,
10    upgrade::Upgraded,
11    Body, Response, StatusCode,
12};
13use sha1::Digest;
14use tokio::task::JoinError;
15use tokio_tungstenite::tungstenite::protocol::Role;
16
17use anyhow::Result;
18use log::*;
19
20pub type WsStream = tokio_tungstenite::WebSocketStream<Upgraded>;
21
/// Fixed GUID from RFC 6455 §1.3. It is appended to the client's `Sec-WebSocket-Key`
/// before hashing to produce the server's `Sec-WebSocket-Accept` value.
const WS_MAGIC_UUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
23
24fn convert_client_key(key: &str) -> String {
25    let to_hash = format!("{}{}", key, WS_MAGIC_UUID);
26
27    let hash = sha1::Sha1::digest(to_hash.as_bytes());
28    encode(&hash)
29}
30
31async fn convert_to_ws_stream(upgraded: Upgraded) -> WsStream {
32    tokio_tungstenite::WebSocketStream::from_partially_read(
33        upgraded,
34        Vec::new(),
35        Role::Server,
36        None,
37    )
38    .await
39}
40
41fn upgrade_connection(
42    mut req: hyper::Request<hyper::Body>,
43) -> impl Future<Output = Result<hyper::Result<WsStream>, JoinError>> {
44    tokio::task::spawn(async move {
45        match hyper::upgrade::on(&mut req).await {
46            Ok(upgraded) => Ok(convert_to_ws_stream(upgraded).await),
47            Err(e) => Err(e),
48        }
49    })
50}
51
52/// Handle a WS handshake and create a tokio_tungstenite stream
53/// Based off of the [Mozilla docs](https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#the_websocket_handshake) on WebSocket servers.
54pub fn create_ws(
55    req: hyper::Request<hyper::Body>,
56) -> http::Result<(
57    hyper::Response<hyper::Body>,
58    Option<impl Future<Output = Result<hyper::Result<WsStream>, JoinError>>>,
59)> {
60    debug!("request headers: {:?}", req.headers());
61
62    let mut res = Response::new(Body::empty());
63
64    // The method must be a GET request:
65    if req.method() != hyper::Method::GET {
66        *res.status_mut() = StatusCode::BAD_REQUEST;
67    }
68
69    // Version must be at least HTTP 1.1
70    if req.version() < hyper::Version::HTTP_11 {
71        *res.status_mut() = StatusCode::BAD_REQUEST;
72    }
73
74    // `Connection: upgrade` header must valid and present
75    if let Some(header_value) = req.headers().get(header::CONNECTION) {
76        if let Ok(value) = header_value.to_str() {
77            if !value
78                .split(",")
79                .map(|s| s.trim())
80                .any(|s| s.eq_ignore_ascii_case("upgrade"))
81            {
82                *res.status_mut() = StatusCode::BAD_REQUEST;
83            }
84        }
85    } else {
86        *res.status_mut() = StatusCode::BAD_REQUEST;
87    }
88
89    // `Upgrade: websocket` header must valid and present
90    if let Some(header_value) = req.headers().get(header::UPGRADE) {
91        if let Ok(value) = header_value.to_str() {
92            if !value.eq_ignore_ascii_case("websocket") {
93                *res.status_mut() = StatusCode::BAD_REQUEST;
94            }
95        }
96    } else {
97        *res.status_mut() = StatusCode::BAD_REQUEST;
98    }
99
100    // Fail before we attempt to upgrade the connection.
101    if res.status() == StatusCode::BAD_REQUEST {
102        return Ok((res, None));
103    }
104
105    if let Some(socket_key) = req.headers().get(header::SEC_WEBSOCKET_KEY) {
106        if let Ok(socket_value) = socket_key.to_str() {
107            trace!("socket key: {:?}", socket_value);
108
109            *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
110            res.headers_mut()
111                .insert(header::UPGRADE, HeaderValue::from_static("websocket"));
112            res.headers_mut()
113                .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
114            res.headers_mut().insert(
115                header::SEC_WEBSOCKET_ACCEPT,
116                HeaderValue::from_str(&convert_client_key(&socket_value))?,
117            );
118
119            return Ok((res, Some(upgrade_connection(req))));
120        }
121    }
122
123    *res.status_mut() = StatusCode::BAD_REQUEST;
124    Ok((res, None))
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::SocketAddr;

    use futures::{SinkExt, StreamExt};
    use http::request::Builder;
    use hyper::{
        service::{make_service_fn, service_fn},
        Client, Method, Request, Version,
    };
    use tokio_tungstenite::{connect_async, tungstenite::Message};

    /// Our server HTTP handler to initiate HTTP upgrades.
    ///
    /// Runs the handshake through [`create_ws`]. If the handshake is accepted, a
    /// background task awaits the upgraded stream and logs each message received until
    /// the stream ends or yields an error. A handshake construction error becomes a 500.
    async fn ws_listener(req: Request<Body>) -> http::Result<Response<Body>> {
        trace!("{:?}", req);

        let (res, ws_fut) = match create_ws(req) {
            Ok(t) => t,
            Err(e) => {
                error!("error creating ws stream: {:?}", e);

                let mut res = Response::new(Body::empty());
                *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                return Ok(res);
            }
        };

        // hyper finishes the upgrade only after `res` is returned, so await the stream
        // in its own task instead of inline here.
        if let Some(ws_fut) = ws_fut {
            tokio::task::spawn(async move {
                if let Ok(Ok(mut stream)) = ws_fut.await {
                    while let Some(Ok(message)) = stream.next().await {
                        debug!("server rx: {:?}", message);
                    }
                }
            });
        }

        Ok(res)
    }

    /// Bind a hyper server running [`ws_listener`] on an ephemeral 127.0.0.1 port.
    ///
    /// The server runs in a background task; the bound address is returned.
    fn create_server() -> SocketAddr {
        let addr = ([127, 0, 0, 1], 0).into();
        let make_service =
            make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(ws_listener)) });

        let server = hyper::Server::bind(&addr).serve(make_service);

        // We need the assigned address for the client to send it messages.
        let addr = server.local_addr();
        debug!("Listening on: {:?}", addr);

        tokio::task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {}", e);
            }
        });

        addr
    }

    #[tokio::test]
    async fn roundtrip_ping() {
        let _ = env_logger::try_init();

        let server_addr = create_server();

        let (stream, res) = connect_async(format!("ws://{}", server_addr))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);

        let (mut write, mut read) = stream.split();

        let data = vec![1, 2, 3, 4, 5];
        let data_c = data.clone();

        tokio::task::spawn(async move { write.send(Message::Ping(data_c)).await });
        let pong = read.next().await.unwrap().unwrap();

        assert_eq!(Message::Pong(data), pong);
    }

    /// A handshake request that `create_ws` accepts. Individual tests remove or
    /// override headers on it to make it invalid.
    fn valid_request(server_addr: SocketAddr) -> Builder {
        Request::builder()
            .method(Method::GET)
            .uri(format!("http://{}", server_addr))
            .version(Version::HTTP_11)
            .header(header::CONNECTION, HeaderValue::from_static("upgrade"))
            .header(header::UPGRADE, HeaderValue::from_static("websocket"))
            .header(
                header::SEC_WEBSOCKET_KEY,
                HeaderValue::from_static("123456"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                HeaderValue::from_static("13"),
            )
    }

    #[tokio::test]
    async fn invalid_request() {
        let _ = env_logger::try_init();

        let server_addr = create_server();
        let client = Client::new();

        let invalid_method = valid_request(server_addr)
            .method(Method::PUT)
            .body(Body::empty())
            .unwrap();

        let resp = client.request(invalid_method).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let invalid_version = valid_request(server_addr)
            .version(Version::HTTP_10)
            .body(Body::empty())
            .unwrap();

        let resp = client.request(invalid_version).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let mut no_connection_header = valid_request(server_addr);
        no_connection_header
            .headers_mut()
            .unwrap()
            .remove(header::CONNECTION);
        let no_connection_header = no_connection_header.body(Body::empty()).unwrap();

        let resp = client.request(no_connection_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let mut no_upgrade_header = valid_request(server_addr);
        no_upgrade_header
            .headers_mut()
            .unwrap()
            .remove(header::UPGRADE);
        let no_upgrade_header = no_upgrade_header.body(Body::empty()).unwrap();

        let resp = client.request(no_upgrade_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let mut no_key_header = valid_request(server_addr);
        no_key_header
            .headers_mut()
            .unwrap()
            .remove(header::SEC_WEBSOCKET_KEY);
        let no_key_header = no_key_header.body(Body::empty()).unwrap();

        let resp = client.request(no_key_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    // The request and response key values are taken from the sample handshake in
    // Mozilla's article on writing WebSocket servers. The same pair appears in
    // RFC 6455 §1.3.
    #[tokio::test]
    async fn valid_key_hash() {
        let _ = env_logger::try_init();

        let server_addr = create_server();
        let client = Client::new();

        let mut request = valid_request(server_addr);

        request
            .headers_mut()
            .unwrap()
            .remove(header::SEC_WEBSOCKET_KEY);

        let request = request
            .header(
                header::SEC_WEBSOCKET_KEY,
                HeaderValue::from_static("dGhlIHNhbXBsZSBub25jZQ=="),
            )
            .body(Body::empty())
            .unwrap();

        let resp = client.request(request).await.unwrap();

        // Check the handshake was accepted first, so a rejection fails with a clear message.
        assert_eq!(resp.status(), StatusCode::SWITCHING_PROTOCOLS);

        let accept_key = resp
            .headers()
            .get(header::SEC_WEBSOCKET_ACCEPT)
            .expect("101 response must carry Sec-WebSocket-Accept");

        assert_eq!(accept_key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }
}