1use super::{pool::ExtractSocketAddr, Client, ExclusiveBody};
2use crate::{
3 error::{FaucetError, FaucetResult},
4 global_conn::{add_connection, remove_connection},
5};
6use base64::Engine;
7use hyper::{
8 header::UPGRADE,
9 http::{uri::PathAndQuery, HeaderValue},
10 upgrade::Upgraded,
11 HeaderMap, Request, Response, StatusCode, Uri,
12};
13use hyper_util::rt::TokioIo;
14use sha1::{Digest, Sha1};
15use std::net::SocketAddr;
16
/// Data captured from an incoming request before it is upgraded, used to
/// open the matching WebSocket connection to the backend.
struct UpgradeInfo {
    /// Copy of the headers of the incoming request.
    headers: HeaderMap,
    /// `ws://` URI built from the backend socket address and the request's
    /// path and query.
    uri: Uri,
}
21
22impl UpgradeInfo {
23 fn new<ReqBody>(req: &Request<ReqBody>, socket_addr: SocketAddr) -> FaucetResult<Self> {
24 let headers = req.headers().clone();
25 let uri = build_uri(socket_addr, req.uri().path_and_query())?;
26 Ok(Self { headers, uri })
27 }
28}
29
/// Fixed GUID that is hashed together with the client's key when computing
/// the `Sec-WebSocket-Accept` value.
const SEC_WEBSOCKET_APPEND: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
/// Name of the request header carrying the client's handshake key.
const SEC_WEBSOCKET_KEY: &str = "Sec-WebSocket-Key";
/// Name of the response header carrying the computed accept value.
const SEC_WEBSOCKET_ACCEPT: &str = "Sec-WebSocket-Accept";
33
34fn calculate_sec_websocket_accept<'buffer>(key: &[u8], buffer: &'buffer mut [u8]) -> &'buffer [u8] {
35 let mut hasher = Sha1::new();
36 hasher.update(key);
37 hasher.update(SEC_WEBSOCKET_APPEND);
38 let len = base64::engine::general_purpose::STANDARD
39 .encode_slice(hasher.finalize(), buffer)
40 .expect("Should always write the internal buffer");
41 &buffer[..len]
42}
43
44fn build_uri(socket_addr: SocketAddr, path: Option<&PathAndQuery>) -> FaucetResult<Uri> {
45 let mut uri_builder = Uri::builder()
46 .scheme("ws")
47 .authority(socket_addr.to_string());
48 match path {
49 Some(path) => uri_builder = uri_builder.path_and_query(path.clone()),
50 None => uri_builder = uri_builder.path_and_query("/"),
51 }
52 Ok(uri_builder.build()?)
53}
54
55async fn server_upgraded_io(upgraded: Upgraded, mut upgrade_info: UpgradeInfo) -> FaucetResult<()> {
56 let mut upgraded = TokioIo::new(upgraded);
57 let mut request = Request::builder().uri(upgrade_info.uri).body(())?;
60 std::mem::swap(request.headers_mut(), &mut upgrade_info.headers);
61 let (mut ws_tx, _) = tokio_tungstenite::connect_async(request).await?;
62
63 tokio::io::copy_bidirectional(&mut upgraded, ws_tx.get_mut()).await?;
65
66 Ok(())
67}
68
/// Outcome of trying to upgrade an incoming request.
pub enum UpgradeStatus<ReqBody> {
    /// The upgrade was initiated; carries the response to send to the client.
    Upgraded(Response<ExclusiveBody>),
    /// The request was not upgraded; carries the original request back.
    NotUpgraded(Request<ReqBody>),
}
73
74async fn upgrade_connection_from_request<ReqBody>(
75 mut req: Request<ReqBody>,
76 client: impl ExtractSocketAddr,
77) -> FaucetResult<()> {
78 let upgrade_info = UpgradeInfo::new(&req, client.socket_addr())?;
79 let upgraded = hyper::upgrade::on(&mut req).await?;
80 server_upgraded_io(upgraded, upgrade_info).await?;
81 Ok(())
82}
83
84async fn init_upgrade<ReqBody: Send + Sync + 'static>(
85 req: Request<ReqBody>,
86 client: impl ExtractSocketAddr + Send + Sync + 'static,
87) -> FaucetResult<Response<ExclusiveBody>> {
88 let mut res = Response::new(ExclusiveBody::empty());
89 let sec_websocket_key = req
90 .headers()
91 .get(SEC_WEBSOCKET_KEY)
92 .cloned()
93 .ok_or(FaucetError::no_sec_web_socket_key())?;
94 tokio::task::spawn(async move {
95 add_connection();
96 if let Err(e) = upgrade_connection_from_request(req, client).await {
97 log::error!("upgrade error: {:?}", e);
98 }
99 remove_connection();
100 });
101 *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
102 res.headers_mut()
103 .insert(UPGRADE, HeaderValue::from_static("websocket"));
104 res.headers_mut().insert(
105 hyper::header::CONNECTION,
106 HeaderValue::from_static("Upgrade"),
107 );
108 let mut buffer = [0u8; 32];
109 res.headers_mut().insert(
110 SEC_WEBSOCKET_ACCEPT,
111 HeaderValue::from_bytes(calculate_sec_websocket_accept(
112 sec_websocket_key.as_bytes(),
113 &mut buffer,
114 ))?,
115 );
116 Ok(res)
117}
118
119#[inline(always)]
120async fn attempt_upgrade<ReqBody: Send + Sync + 'static>(
121 req: Request<ReqBody>,
122 client: impl ExtractSocketAddr + Send + Sync + 'static,
123) -> FaucetResult<UpgradeStatus<ReqBody>> {
124 if req.headers().contains_key(UPGRADE) {
125 return Ok(UpgradeStatus::Upgraded(init_upgrade(req, client).await?));
126 }
127 Ok(UpgradeStatus::NotUpgraded(req))
128}
129
130impl Client {
131 pub async fn attempt_upgrade<ReqBody>(
132 &self,
133 req: Request<ReqBody>,
134 ) -> FaucetResult<UpgradeStatus<ReqBody>>
135 where
136 ReqBody: Send + Sync + 'static,
137 {
138 attempt_upgrade(req, self.clone()).await
139 }
140}
141
#[cfg(test)]
mod tests {
    use crate::networking::get_available_socket;

    use super::*;

    /// Sample handshake key used throughout the tests.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    /// Accept value expected for [`SAMPLE_KEY`].
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// Minimal [`ExtractSocketAddr`] implementation pointing at a fixed address.
    struct MockClient {
        socket_addr: SocketAddr,
    }

    impl ExtractSocketAddr for MockClient {
        fn socket_addr(&self) -> SocketAddr {
            self.socket_addr
        }
    }

    /// Picks an available socket address and starts the dummy WebSocket
    /// server on it, returning the address and the server task handle.
    async fn spawn_backend() -> (SocketAddr, tokio::task::JoinHandle<()>) {
        let socket_addr = get_available_socket(20).await.unwrap();
        let server = tokio::spawn(async move {
            dummy_websocket_server::run(socket_addr).await.unwrap();
        });
        (socket_addr, server)
    }

    /// Builds an `http://<socket_addr>/` request, optionally carrying the
    /// `Upgrade: websocket` header and/or the sample `Sec-WebSocket-Key`.
    fn build_request(socket_addr: SocketAddr, upgrade: bool, sec_key: bool) -> Request<()> {
        let uri = Uri::builder()
            .scheme("http")
            .authority(socket_addr.to_string().as_str())
            .path_and_query("/")
            .build()
            .unwrap();

        let mut builder = Request::builder().uri(uri);
        if upgrade {
            builder = builder.header(UPGRADE, "websocket");
        }
        if sec_key {
            builder = builder.header(SEC_WEBSOCKET_KEY, SAMPLE_KEY);
        }
        builder.body(()).unwrap()
    }

    /// Asserts that `res` is the handshake response expected for [`SAMPLE_KEY`].
    fn assert_switching_protocols(res: &Response<ExclusiveBody>) {
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
        assert_eq!(
            res.headers().get(UPGRADE).unwrap(),
            HeaderValue::from_static("websocket")
        );
        assert_eq!(
            res.headers().get(SEC_WEBSOCKET_ACCEPT).unwrap(),
            HeaderValue::from_static(SAMPLE_ACCEPT)
        );
        assert_eq!(
            res.headers().get(hyper::header::CONNECTION).unwrap(),
            HeaderValue::from_static("Upgrade")
        );
    }

    #[test]
    fn test_calculate_sec_websocket_accept() {
        let mut buffer = [0u8; 32];
        let accept = calculate_sec_websocket_accept(SAMPLE_KEY.as_bytes(), &mut buffer);
        assert_eq!(accept, SAMPLE_ACCEPT.as_bytes());
    }

    #[test]
    fn test_build_uri() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path_and_query = "/websocket".parse().unwrap();
        let path = Some(&path_and_query);
        let result = build_uri(socket_addr, path).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000/websocket");
    }

    #[test]
    fn build_uri_no_path() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path = None;
        let result = build_uri(socket_addr, path).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000");
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request() {
        let (socket_addr, server) = spawn_backend().await;
        let client = MockClient { socket_addr };
        let req = build_request(socket_addr, true, true);

        let result = init_upgrade(req, client).await.unwrap();

        server.abort();

        assert_switching_protocols(&result);
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request_no_sec_key() {
        let (socket_addr, server) = spawn_backend().await;
        let client = MockClient { socket_addr };
        let req = build_request(socket_addr, true, false);

        let result = init_upgrade(req, client).await;

        server.abort();

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_attempt_upgrade_no_upgrade_header() {
        let (socket_addr, server) = spawn_backend().await;
        let client = MockClient { socket_addr };
        let req = build_request(socket_addr, false, true);

        let result = attempt_upgrade(req, client).await.unwrap();

        server.abort();

        match result {
            UpgradeStatus::NotUpgraded(_) => {}
            UpgradeStatus::Upgraded(_) => panic!("Expected NotUpgraded"),
        }
    }

    #[tokio::test]
    async fn test_attempt_upgrade_with_upgrade_header() {
        let (socket_addr, server) = spawn_backend().await;
        let client = MockClient { socket_addr };
        let req = build_request(socket_addr, true, true);

        let result = attempt_upgrade(req, client).await.unwrap();

        server.abort();

        match result {
            UpgradeStatus::Upgraded(res) => assert_switching_protocols(&res),
            // The message previously said "Expected NotUpgraded", which
            // contradicted what this test checks.
            UpgradeStatus::NotUpgraded(_) => panic!("Expected Upgraded"),
        }
    }

    #[tokio::test]
    async fn test_upgrade_connection_from_request() {
        let (socket_addr, server) = spawn_backend().await;
        let client = MockClient { socket_addr };
        let req = build_request(socket_addr, true, true);

        // NOTE(review): the join result is discarded, so a panic from the
        // assertion inside the task does not fail this test — confirm whether
        // that is intended.
        let _ = tokio::spawn(async move {
            let result = upgrade_connection_from_request(req, client).await;
            assert!(result.is_ok());
        })
        .await;

        server.abort();
    }

    /// Small WebSocket echo server used as the backend in the tests above.
    mod dummy_websocket_server {
        use std::{io::Error, net::SocketAddr};

        use futures_util::{future, StreamExt, TryStreamExt};
        use log::info;
        use tokio::net::{TcpListener, TcpStream};

        /// Binds to `addr` and serves every accepted connection on its own task.
        pub async fn run(addr: SocketAddr) -> Result<(), Error> {
            let try_socket = TcpListener::bind(&addr).await;
            let listener = try_socket.expect("Failed to bind");
            info!("Listening on: {}", addr);

            while let Ok((stream, _)) = listener.accept().await {
                tokio::spawn(accept_connection(stream));
            }

            Ok(())
        }

        /// Completes the WebSocket handshake on `stream` and echoes back every
        /// text or binary message it receives.
        async fn accept_connection(stream: TcpStream) {
            let addr = stream
                .peer_addr()
                .expect("connected streams should have a peer address");
            info!("Peer address: {}", addr);

            let ws_stream = tokio_tungstenite::accept_async(stream)
                .await
                .expect("Error during the websocket handshake occurred");

            info!("New WebSocket connection: {}", addr);

            let (write, read) = ws_stream.split();
            read.try_filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
                .forward(write)
                .await
                .expect("Failed to forward messages")
        }
    }
}