1use base64::encode;
5use std::future::Future;
6
7use hyper::{
8 header::{self, HeaderValue},
9 http,
10 upgrade::Upgraded,
11 Body, Response, StatusCode,
12};
13use sha1::Digest;
14use tokio::task::JoinError;
15use tokio_tungstenite::tungstenite::protocol::Role;
16
17use anyhow::Result;
18use log::*;
19
20pub type WsStream = tokio_tungstenite::WebSocketStream<Upgraded>;
21
/// GUID that RFC 6455 appends to the client's `Sec-WebSocket-Key` before
/// hashing, to derive the `Sec-WebSocket-Accept` response value.
const WS_MAGIC_UUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
23
24fn convert_client_key(key: &str) -> String {
25 let to_hash = format!("{}{}", key, WS_MAGIC_UUID);
26
27 let hash = sha1::Sha1::digest(to_hash.as_bytes());
28 encode(&hash)
29}
30
31async fn convert_to_ws_stream(upgraded: Upgraded) -> WsStream {
32 tokio_tungstenite::WebSocketStream::from_partially_read(
33 upgraded,
34 Vec::new(),
35 Role::Server,
36 None,
37 )
38 .await
39}
40
41fn upgrade_connection(
42 mut req: hyper::Request<hyper::Body>,
43) -> impl Future<Output = Result<hyper::Result<WsStream>, JoinError>> {
44 tokio::task::spawn(async move {
45 match hyper::upgrade::on(&mut req).await {
46 Ok(upgraded) => Ok(convert_to_ws_stream(upgraded).await),
47 Err(e) => Err(e),
48 }
49 })
50}
51
52pub fn create_ws(
55 req: hyper::Request<hyper::Body>,
56) -> http::Result<(
57 hyper::Response<hyper::Body>,
58 Option<impl Future<Output = Result<hyper::Result<WsStream>, JoinError>>>,
59)> {
60 debug!("request headers: {:?}", req.headers());
61
62 let mut res = Response::new(Body::empty());
63
64 if req.method() != hyper::Method::GET {
66 *res.status_mut() = StatusCode::BAD_REQUEST;
67 }
68
69 if req.version() < hyper::Version::HTTP_11 {
71 *res.status_mut() = StatusCode::BAD_REQUEST;
72 }
73
74 if let Some(header_value) = req.headers().get(header::CONNECTION) {
76 if let Ok(value) = header_value.to_str() {
77 if !value
78 .split(",")
79 .map(|s| s.trim())
80 .any(|s| s.eq_ignore_ascii_case("upgrade"))
81 {
82 *res.status_mut() = StatusCode::BAD_REQUEST;
83 }
84 }
85 } else {
86 *res.status_mut() = StatusCode::BAD_REQUEST;
87 }
88
89 if let Some(header_value) = req.headers().get(header::UPGRADE) {
91 if let Ok(value) = header_value.to_str() {
92 if !value.eq_ignore_ascii_case("websocket") {
93 *res.status_mut() = StatusCode::BAD_REQUEST;
94 }
95 }
96 } else {
97 *res.status_mut() = StatusCode::BAD_REQUEST;
98 }
99
100 if res.status() == StatusCode::BAD_REQUEST {
102 return Ok((res, None));
103 }
104
105 if let Some(socket_key) = req.headers().get(header::SEC_WEBSOCKET_KEY) {
106 if let Ok(socket_value) = socket_key.to_str() {
107 trace!("socket key: {:?}", socket_value);
108
109 *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
110 res.headers_mut()
111 .insert(header::UPGRADE, HeaderValue::from_static("websocket"));
112 res.headers_mut()
113 .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
114 res.headers_mut().insert(
115 header::SEC_WEBSOCKET_ACCEPT,
116 HeaderValue::from_str(&convert_client_key(&socket_value))?,
117 );
118
119 return Ok((res, Some(upgrade_connection(req))));
120 }
121 }
122
123 *res.status_mut() = StatusCode::BAD_REQUEST;
124 Ok((res, None))
125}
126
#[cfg(test)]
mod tests {
    use super::*;

    use std::net::SocketAddr;

    use futures::{SinkExt, StreamExt};
    use http::request::Builder;
    use hyper::{
        service::{make_service_fn, service_fn},
        Client, Method, Request, Version,
    };
    use tokio_tungstenite::{connect_async, tungstenite::Message};

    /// Test server handler. It runs the handshake through `create_ws`.
    ///
    /// When an upgrade future is returned, a task is spawned that logs every
    /// message received on the upgraded stream until the stream ends or
    /// yields an error. A `create_ws` error becomes `500 Internal Server Error`.
    async fn ws_listener(req: Request<Body>) -> http::Result<Response<Body>> {
        trace!("{:?}", req);

        let (res, ws_fut) = match create_ws(req) {
            Ok(t) => t,
            Err(e) => {
                error!("error creating ws stream: {:?}", e);

                let mut res = Response::new(Body::empty());
                *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
                return Ok(res);
            }
        };

        if let Some(ws_fut) = ws_fut {
            // Drive the upgraded stream in the background so the handshake
            // response can be returned right away.
            tokio::task::spawn(async move {
                if let Ok(Ok(mut stream)) = ws_fut.await {
                    while let Some(Ok(message)) = stream.next().await {
                        debug!("server rx: {:?}", message);
                    }
                }
            });
        }

        Ok(res)
    }

    /// Starts a hyper server on 127.0.0.1 that serves every connection with
    /// `ws_listener`, and returns the address it is bound to.
    ///
    /// The server runs on a detached tokio task; a server error is only printed.
    fn create_server() -> SocketAddr {
        // Port 0 lets the OS pick a free port, so each test gets its own server.
        let addr = ([127, 0, 0, 1], 0).into();
        let make_service =
            make_service_fn(|_| async { Ok::<_, hyper::Error>(service_fn(ws_listener)) });

        let server = hyper::Server::bind(&addr).serve(make_service);

        // Read back the port the OS actually assigned.
        let addr = server.local_addr();
        debug!("Listening on: {:?}", addr);

        tokio::task::spawn(async move {
            if let Err(e) = server.await {
                eprintln!("server error: {}", e);
            }
        });

        addr
    }

    /// Full client/server round trip. After a successful handshake, a Ping
    /// sent by the client must come back as a Pong with the same payload.
    #[tokio::test]
    async fn roundtrip_ping() {
        let server_addr = create_server();

        let (stream, res) = connect_async(format!("ws://{}", server_addr))
            .await
            .unwrap();

        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);

        let (mut write, mut read) = stream.split();

        let data = vec![1, 2, 3, 4, 5];
        let data_c = data.clone();

        // The server loop in `ws_listener` never sends anything explicitly.
        // The Pong is presumably produced by tungstenite's automatic Ping
        // handling while that loop keeps reading.
        tokio::task::spawn(async move { write.send(Message::Ping(data_c)).await });
        let pong = read.next().await.unwrap().unwrap();

        assert_eq!(Message::Pong(data), pong);
    }

    /// Returns a request builder for a handshake that passes every check in
    /// `create_ws`. Tests remove or override one part of it to make it invalid.
    ///
    /// `"123456"` is only a placeholder key, not a real base64-encoded nonce.
    fn valid_request(server_addr: SocketAddr) -> Builder {
        Request::builder()
            .method(Method::GET)
            .uri(format!("http://{}", server_addr))
            .version(Version::HTTP_11)
            .header(header::CONNECTION, HeaderValue::from_static("upgrade"))
            .header(header::UPGRADE, HeaderValue::from_static("websocket"))
            .header(
                header::SEC_WEBSOCKET_KEY,
                HeaderValue::from_static("123456"),
            )
            .header(
                header::SEC_WEBSOCKET_VERSION,
                HeaderValue::from_static("13"),
            )
    }

    /// Each request below breaks exactly one handshake requirement and must be
    /// rejected with `400 Bad Request`.
    #[tokio::test]
    async fn invalid_request() {
        let _ = env_logger::try_init();

        let server_addr = create_server();
        let client = Client::new();

        // Wrong method: the handshake must be a GET.
        let invalid_method = valid_request(server_addr)
            .method(Method::PUT)
            .body(Body::empty())
            .unwrap();

        let resp = client.request(invalid_method).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // HTTP/1.0 is below the required HTTP/1.1.
        let invalid_version = valid_request(server_addr)
            .version(Version::HTTP_10)
            .body(Body::empty())
            .unwrap();

        let resp = client.request(invalid_version).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // Missing `Connection: upgrade`.
        let mut no_connection_header = valid_request(server_addr);
        no_connection_header
            .headers_mut()
            .unwrap()
            .remove(header::CONNECTION);
        let no_connection_header = no_connection_header.body(Body::empty()).unwrap();

        let resp = client.request(no_connection_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // Missing `Upgrade: websocket`.
        let mut no_upgrade_header = valid_request(server_addr);
        no_upgrade_header
            .headers_mut()
            .unwrap()
            .remove(header::UPGRADE);
        let no_upgrade_header = no_upgrade_header.body(Body::empty()).unwrap();

        let resp = client.request(no_upgrade_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        // Missing `Sec-WebSocket-Key`.
        let mut no_key_header = valid_request(server_addr);
        no_key_header
            .headers_mut()
            .unwrap()
            .remove(header::SEC_WEBSOCKET_KEY);
        let no_key_header = no_key_header.body(Body::empty()).unwrap();

        let resp = client.request(no_key_header).await.unwrap();
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
    }

    /// Checks the accept-key derivation against the worked example in
    /// RFC 6455 section 1.3.
    #[tokio::test]
    async fn valid_key_hash() {
        let server_addr = create_server();
        let client = Client::new();

        let mut request = valid_request(server_addr);

        // `Builder::header` appends rather than replaces. Drop the placeholder
        // key first so only the sample key is sent.
        request
            .headers_mut()
            .unwrap()
            .remove(header::SEC_WEBSOCKET_KEY);

        let request = request
            .header(
                header::SEC_WEBSOCKET_KEY,
                HeaderValue::from_static("dGhlIHNhbXBsZSBub25jZQ=="),
            )
            .body(Body::empty())
            .unwrap();

        let resp = client.request(request).await.unwrap();

        let accept_key = &resp.headers()[header::SEC_WEBSOCKET_ACCEPT];

        assert_eq!(accept_key, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=")
    }
}