1use super::{pool::ExtractSocketAddr, Client, ExclusiveBody};
2use crate::{
3 error::{BadRequestReason, FaucetError, FaucetResult},
4 global_conn::{add_connection, remove_connection},
5 server::logging::{EventLogData, FaucetTracingLevel},
6 shutdown::ShutdownSignal,
7 telemetry::send_log_event,
8};
9use base64::Engine;
10use bytes::Bytes;
11use futures_util::StreamExt;
12use hyper::{
13 header::UPGRADE,
14 http::{uri::PathAndQuery, HeaderValue},
15 upgrade::Upgraded,
16 HeaderMap, Request, Response, StatusCode, Uri,
17};
18use hyper_util::rt::TokioIo;
19use serde_json::json;
20use sha1::{Digest, Sha1};
21use std::{
22 collections::HashMap, future::Future, net::SocketAddr, str::FromStr, sync::LazyLock,
23 time::Duration,
24};
25use tokio::sync::Mutex;
26use tokio_tungstenite::tungstenite::{
27 protocol::{frame::coding::CloseCode, CloseFrame, WebSocketConfig},
28 Message, Utf8Bytes,
29};
30use uuid::Uuid;
31
32struct UpgradeInfo {
33 headers: HeaderMap,
34 uri: Uri,
35}
36
37impl UpgradeInfo {
38 fn new<ReqBody>(req: &Request<ReqBody>, socket_addr: SocketAddr) -> FaucetResult<Self> {
39 let headers = req.headers().clone();
40 let uri = build_uri(socket_addr, req.uri().path_and_query())?;
41 Ok(Self { headers, uri })
42 }
43}
44
45const SEC_WEBSOCKET_APPEND: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
46const SEC_WEBSOCKET_KEY: &str = "Sec-WebSocket-Key";
47const SEC_WEBSOCKET_ACCEPT: &str = "Sec-WebSocket-Accept";
48
49fn calculate_sec_websocket_accept<'buffer>(key: &[u8], buffer: &'buffer mut [u8]) -> &'buffer [u8] {
50 let mut hasher = Sha1::new();
51 hasher.update(key);
52 hasher.update(SEC_WEBSOCKET_APPEND);
53 let len = base64::engine::general_purpose::STANDARD
54 .encode_slice(hasher.finalize(), buffer)
55 .expect("Should always write the internal buffer");
56 &buffer[..len]
57}
58
59fn build_uri(socket_addr: SocketAddr, path: Option<&PathAndQuery>) -> FaucetResult<Uri> {
60 let mut uri_builder = Uri::builder()
61 .scheme("ws")
62 .authority(socket_addr.to_string());
63 match path {
64 Some(path) => uri_builder = uri_builder.path_and_query(path.clone()),
65 None => uri_builder = uri_builder.path_and_query("/"),
66 }
67 Ok(uri_builder.build()?)
68}
69
70use futures_util::SinkExt;
73
74type ConnectionPair = (
75 futures_util::stream::SplitSink<
76 tokio_tungstenite::WebSocketStream<
77 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
78 >,
79 tokio_tungstenite::tungstenite::Message,
80 >,
81 futures_util::stream::SplitStream<
82 tokio_tungstenite::WebSocketStream<
83 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
84 >,
85 >,
86);
87
88#[derive(Default)]
89struct ConnectionInstance {
90 purged: bool,
91 access_count: usize,
92 pair: Option<ConnectionPair>,
93}
94
95impl ConnectionInstance {
96 fn take(&mut self) -> ConnectionPair {
97 self.access_count += 1;
98 self.pair.take().unwrap()
99 }
100 fn put_back(&mut self, pair: ConnectionPair) {
101 self.access_count += 1;
102 self.pair = Some(pair);
103 }
104}
105
106struct ConnectionManagerInner {
107 map: HashMap<Uuid, ConnectionInstance>,
108 purge_count: usize,
109}
110
111struct ConnectionManager {
112 inner: Mutex<ConnectionManagerInner>,
113}
114
115impl ConnectionManager {
116 fn new() -> Self {
117 ConnectionManager {
118 inner: Mutex::new(ConnectionManagerInner {
119 map: HashMap::new(),
120 purge_count: 0,
121 }),
122 }
123 }
124 async fn initialize_if_not(
125 &self,
126 session_id: Uuid,
127 attempt: usize,
128 init: impl Future<Output = FaucetResult<ConnectionPair>>,
129 ) -> Option<FaucetResult<ConnectionPair>> {
130 {
131 let mut inner = self.inner.lock().await;
132 let entry = inner.map.entry(session_id).or_default();
133 if entry.access_count != 0 {
134 return None;
135 }
136 if entry.purged {
137 return Some(Err(FaucetError::WebSocketConnectionPurged));
138 }
139
140 if entry.access_count == 0 && attempt > 0 {
141 return Some(Err(FaucetError::WebSocketConnectionPurged));
142 }
143
144 entry.access_count += 1;
145 }
146 let connection_pair = match init.await {
147 Ok(connection_pair) => connection_pair,
148 Err(e) => return Some(Err(e)),
149 };
150 Some(Ok(connection_pair))
151 }
152 async fn attempt_take(&self, session_id: Uuid) -> FaucetResult<ConnectionPair> {
153 match self.inner.try_lock() {
154 Ok(mut inner) => {
155 let instance = inner.map.entry(session_id).or_default();
156
157 if instance.access_count % 2 == 0 {
158 return Ok(instance.take());
159 }
160
161 Err(FaucetError::WebSocketConnectionInUse)
162 }
163 _ => Err(FaucetError::WebSocketConnectionInUse),
164 }
165 }
166 async fn put_pack(&self, session_id: Uuid, pair: ConnectionPair) {
167 let mut inner = self.inner.lock().await;
168 if let Some(instance) = inner.map.get_mut(&session_id) {
169 instance.put_back(pair);
170 }
171 }
172 async fn remove_session(&self, session_id: Uuid) {
173 let mut inner = self.inner.lock().await;
174 inner.map.remove(&session_id);
175 inner.purge_count += 1;
176 if let Some(instance) = inner.map.get_mut(&session_id) {
177 instance.purged = true;
178 }
179 }
180}
181
182static SHINY_CONNECTION_CACHE: LazyLock<ConnectionManager> = LazyLock::new(ConnectionManager::new);
186
187async fn connect_to_worker(
188 mut upgrade_info: UpgradeInfo,
189 session_id: Uuid,
190 config: &'static WebSocketConfig,
191) -> FaucetResult<ConnectionPair> {
192 let mut request = Request::builder().uri(upgrade_info.uri).body(())?;
193 upgrade_info.headers.append(
194 "FAUCET_SESSION_ID",
195 HeaderValue::from_str(&session_id.to_string())
196 .expect("Unable to set Session ID as header. This is a bug. please report it!"),
197 );
198 *request.headers_mut() = upgrade_info.headers;
199 let (shiny_ws, _) =
200 tokio_tungstenite::connect_async_with_config(request, Some(*config), false).await?;
201 send_log_event(EventLogData {
202 target: "faucet".into(),
203 event_id: session_id,
204 parent_event_id: None,
205 level: FaucetTracingLevel::Info,
206 event_type: "websocket_connection".into(),
207 message: "Established new WebSocket connection to shiny".to_string(),
208 body: None,
209 });
210 Ok(shiny_ws.split())
211}
212
213async fn connect_or_retrieve(
214 upgrade_info: UpgradeInfo,
215 session_id: Uuid,
216 attempt: usize,
217 config: &'static WebSocketConfig,
218) -> FaucetResult<ConnectionPair> {
219 let init_pair = SHINY_CONNECTION_CACHE
220 .initialize_if_not(
221 session_id,
222 attempt,
223 connect_to_worker(upgrade_info, session_id, config),
224 )
225 .await;
226
227 match init_pair {
228 None => {
229 match SHINY_CONNECTION_CACHE.attempt_take(session_id).await {
232 Ok(con) => {
233 send_log_event(EventLogData {
234 target: "faucet".into(),
235 event_id: Uuid::new_v4(),
236 parent_event_id: Some(session_id),
237 event_type: "websocket_connection".into(),
238 level: FaucetTracingLevel::Info,
239 message: "Client successfully reconnected".to_string(),
240 body: Some(json!({"attempts": attempt})),
241 });
242 Ok(con)
243 }
244 Err(e) => FaucetResult::Err(e),
245 }
246 }
247 Some(init_pair_res) => init_pair_res,
248 }
249}
250
251const RECHECK_TIME: Duration = Duration::from_secs(60);
252const PING_INTERVAL: Duration = Duration::from_secs(1);
253const PING_INTERVAL_TIMEOUT: Duration = Duration::from_secs(30);
254const PING_BYTES: Bytes = Bytes::from_static(b"Ping");
255
256async fn server_upgraded_io(
257 upgraded: Upgraded,
258 upgrade_info: UpgradeInfo,
259 session_id: Uuid,
260 attempt: usize,
261 shutdown: &'static ShutdownSignal,
262 websocket_config: &'static WebSocketConfig,
263) -> FaucetResult<()> {
264 let upgraded = TokioIo::new(upgraded);
266 let upgraded_ws = tokio_tungstenite::WebSocketStream::from_raw_socket(
267 upgraded,
268 tokio_tungstenite::tungstenite::protocol::Role::Server,
269 Some(*websocket_config),
270 )
271 .await;
272 let (mut upgraded_tx, mut upgraded_rx) = upgraded_ws.split();
273
274 let (mut shiny_tx, mut shiny_rx) =
276 match connect_or_retrieve(upgrade_info, session_id, attempt, websocket_config).await {
277 Ok(pair) => pair,
278 Err(e) => match e {
279 FaucetError::WebSocketConnectionPurged => {
280 upgraded_tx
281 .send(Message::Close(Some(CloseFrame {
282 code: CloseCode::Normal,
283 reason: Utf8Bytes::from_static(
284 "Connection purged due to inactivity, update or error.",
285 ),
286 })))
287 .await?;
288 return Err(FaucetError::WebSocketConnectionPurged);
289 }
290 e => return Err(e),
291 },
292 };
293
294 let client_to_shiny = async {
297 loop {
298 log::debug!("Waiting for message or ping timeout");
299 tokio::select! {
300 msg = upgraded_rx.next() => {
301 log::debug!("Received msg: {msg:?}");
302 match msg {
303 Some(Ok(msg)) => {
304 if shiny_tx.send(msg).await.is_err() {
305 break; }
307 },
308 Some(Err(e)) => {
309 log::error!("Error sending websocket message to shiny: {e}");
310 break
311 }
312 _ => break
313 }
314 },
315 _ = tokio::time::sleep(PING_INTERVAL_TIMEOUT) => {
316 log::debug!("Ping timeout reached for session {session_id}");
317 break;
318 }
319 }
320 }
321 };
322
323 let shiny_to_client = async {
324 loop {
325 let ping_future = async {
326 tokio::time::sleep(PING_INTERVAL).await;
327 upgraded_tx.send(Message::Ping(PING_BYTES)).await
328 };
329 tokio::select! {
330 msg = shiny_rx.next() => {
331 match msg {
332 Some(Ok(msg)) => {
333 if upgraded_tx.send(msg).await.is_err() {
334 break; }
336 },
337 Some(Err(e)) => {
338 log::error!("Error sending websocket message to client: {e}");
339 break
340 }
341 _ => break
342 }
343 },
344 _ = ping_future => {}
345 }
346 }
347 };
348
349 tokio::select! {
351 _ = client_to_shiny => {
352 send_log_event(EventLogData {
353 target: "faucet".into(),
354 event_id: Uuid::new_v4(),
355 parent_event_id: Some(session_id),
356 event_type: "websocket_connection".into(),
357 level: FaucetTracingLevel::Info,
358 message: "Session ended by client.".to_string(),
359 body: None,
360 });
361 log::debug!("Client connection closed for session {session_id}.")
362 },
363 _ = shiny_to_client => {
364 SHINY_CONNECTION_CACHE.remove_session(session_id).await;
367 send_log_event(EventLogData {
368 target: "faucet".into(),
369 event_id: Uuid::new_v4(),
370 parent_event_id: Some(session_id),
371 event_type: "websocket_connection".into(),
372 level: FaucetTracingLevel::Info,
373 message: "Shiny session ended by Shiny.".to_string(),
374 body: None,
375 });
376 log::debug!("Shiny connection closed for session {session_id}.");
377 return Ok(());
378 },
379 _ = shutdown.wait() => {
380 log::debug!("Received shutdown signal. Exiting websocket bridge.");
381 return Ok(());
382 }
383 };
384
385 log::debug!("Client websocket connection to session {session_id} ended but the Shiny connection is still alive. Saving for reconnection.");
389 SHINY_CONNECTION_CACHE
390 .put_pack(session_id, (shiny_tx, shiny_rx))
391 .await;
392
393 tokio::select! {
395 _ = tokio::time::sleep(RECHECK_TIME) => {
396 let entry = SHINY_CONNECTION_CACHE.attempt_take(session_id).await;
397 match entry {
398 Err(_) => (),
399 Ok((shiny_tx, shiny_rx)) => {
400 let mut ws = shiny_tx
401 .reunite(shiny_rx)
402 .expect("shiny_rx and tx always have the same origin.");
403 if ws
405 .close(Some(CloseFrame {
406 code: CloseCode::Abnormal,
407 reason: Utf8Bytes::default(),
408 }))
409 .await
410 .is_ok()
411 {
412 log::debug!("Closed reserved connection for session {session_id}");
413 }
414 SHINY_CONNECTION_CACHE.remove_session(session_id).await;
415 }
416 }
417 },
418 _ = shutdown.wait() => {
419 log::debug!("Shutdown signaled, not running websocket cleanup for session {session_id}");
420 }
421 }
422
423 Ok(())
424}
425
426pub enum UpgradeStatus<ReqBody> {
427 Upgraded(Response<ExclusiveBody>),
428 NotUpgraded(Request<ReqBody>),
429}
430
/// Query parameter carrying the Shiny session id.
const SESSION_ID_QUERY: &str = "sessionId";

/// Returns `true` when `this` and `that` are equal ignoring ASCII case.
///
/// Non-ASCII characters must match exactly. This delegates to the standard
/// library instead of hand-rolling the byte comparison.
fn case_insensitive_eq(this: &str, that: &str) -> bool {
    this.eq_ignore_ascii_case(that)
}
442
443async fn upgrade_connection_from_request<ReqBody>(
444 mut req: Request<ReqBody>,
445 client: impl ExtractSocketAddr,
446 shutdown: &'static ShutdownSignal,
447 websocket_config: &'static WebSocketConfig,
448) -> FaucetResult<()> {
449 let query = req.uri().query().ok_or(FaucetError::BadRequest(
451 BadRequestReason::MissingQueryParam("Unable to parse query params"),
452 ))?;
453
454 let mut session_id: Option<uuid::Uuid> = None;
455 let mut attempt: Option<usize> = None;
456
457 url::form_urlencoded::parse(query.as_bytes()).for_each(|(key, value)| {
458 if case_insensitive_eq(&key, SESSION_ID_QUERY) {
459 session_id = uuid::Uuid::from_str(&value).ok();
460 } else if case_insensitive_eq(&key, "attempt") {
461 attempt = value.parse::<usize>().ok();
462 }
463 });
464
465 let session_id = session_id.ok_or(FaucetError::BadRequest(
466 BadRequestReason::MissingQueryParam("sessionId"),
467 ))?;
468
469 let attempt = attempt.ok_or(FaucetError::BadRequest(
470 BadRequestReason::MissingQueryParam("attempt"),
471 ))?;
472
473 let upgrade_info = UpgradeInfo::new(&req, client.socket_addr())?;
474 let upgraded = hyper::upgrade::on(&mut req).await?;
475 server_upgraded_io(
476 upgraded,
477 upgrade_info,
478 session_id,
479 attempt,
480 shutdown,
481 websocket_config,
482 )
483 .await?;
484 Ok(())
485}
486
487async fn init_upgrade<ReqBody: Send + Sync + 'static>(
488 req: Request<ReqBody>,
489 client: impl ExtractSocketAddr + Send + Sync + 'static,
490 shutdown: &'static ShutdownSignal,
491 websocket_config: &'static WebSocketConfig,
492) -> FaucetResult<Response<ExclusiveBody>> {
493 let mut res = Response::new(ExclusiveBody::empty());
494 let sec_websocket_key = req
495 .headers()
496 .get(SEC_WEBSOCKET_KEY)
497 .cloned()
498 .ok_or(FaucetError::no_sec_web_socket_key())?;
499 tokio::task::spawn(async move {
500 add_connection();
501 if let Err(e) =
502 upgrade_connection_from_request(req, client, shutdown, websocket_config).await
503 {
504 log::error!("upgrade error: {e:?}");
505 }
506 remove_connection();
507 });
508 *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
509 res.headers_mut()
510 .insert(UPGRADE, HeaderValue::from_static("websocket"));
511 res.headers_mut().insert(
512 hyper::header::CONNECTION,
513 HeaderValue::from_static("Upgrade"),
514 );
515 let mut buffer = [0u8; 32];
516 res.headers_mut().insert(
517 SEC_WEBSOCKET_ACCEPT,
518 HeaderValue::from_bytes(calculate_sec_websocket_accept(
519 sec_websocket_key.as_bytes(),
520 &mut buffer,
521 ))?,
522 );
523 Ok(res)
524}
525
526#[inline(always)]
527async fn attempt_upgrade<ReqBody: Send + Sync + 'static>(
528 req: Request<ReqBody>,
529 client: impl ExtractSocketAddr + Send + Sync + 'static,
530 shutdown: &'static ShutdownSignal,
531 websocket_config: &'static WebSocketConfig,
532) -> FaucetResult<UpgradeStatus<ReqBody>> {
533 if req.headers().contains_key(UPGRADE) {
534 return Ok(UpgradeStatus::Upgraded(
535 init_upgrade(req, client, shutdown, websocket_config).await?,
536 ));
537 }
538 Ok(UpgradeStatus::NotUpgraded(req))
539}
540
541impl Client {
542 pub async fn attempt_upgrade<ReqBody>(
543 &self,
544 req: Request<ReqBody>,
545 shutdown: &'static ShutdownSignal,
546 websocket_config: &'static WebSocketConfig,
547 ) -> FaucetResult<UpgradeStatus<ReqBody>>
548 where
549 ReqBody: Send + Sync + 'static,
550 {
551 attempt_upgrade(req, self.clone(), shutdown, websocket_config).await
552 }
553}
554
#[cfg(test)]
mod tests {
    use crate::{leak, networking::get_available_socket, shutdown::ShutdownSignal};

    use super::*;
    use uuid::Uuid;

    /// Handshake key from the RFC 6455 example and its expected accept value.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// `ExtractSocketAddr` implementation returning a fixed address.
    struct MockClient {
        socket_addr: SocketAddr,
    }

    impl ExtractSocketAddr for MockClient {
        fn socket_addr(&self) -> SocketAddr {
            self.socket_addr
        }
    }

    /// Reserves a local port and starts the echo WebSocket server on it.
    ///
    /// Returns a client pointing at that port plus the server task, which the
    /// caller aborts when done.
    async fn start_dummy_server() -> (MockClient, tokio::task::JoinHandle<()>) {
        let socket_addr = get_available_socket(20).await.unwrap();
        let server = tokio::spawn(async move {
            dummy_websocket_server::run(socket_addr).await.unwrap();
        });
        (MockClient { socket_addr }, server)
    }

    /// Builds `http://<socket_addr>/?sessionId=<fresh uuid>`.
    fn session_uri(socket_addr: SocketAddr) -> Uri {
        Uri::builder()
            .scheme("http")
            .authority(socket_addr.to_string().as_str())
            .path_and_query(format!("/?{}={}", SESSION_ID_QUERY, Uuid::now_v7()))
            .build()
            .unwrap()
    }

    /// Asserts `res` is the `101` answer expected for `SAMPLE_KEY`.
    fn assert_switching_protocols(res: &Response<ExclusiveBody>) {
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
        assert_eq!(
            res.headers().get(UPGRADE).unwrap(),
            HeaderValue::from_static("websocket")
        );
        assert_eq!(
            res.headers().get(SEC_WEBSOCKET_ACCEPT).unwrap(),
            HeaderValue::from_static(SAMPLE_ACCEPT)
        );
        assert_eq!(
            res.headers().get(hyper::header::CONNECTION).unwrap(),
            HeaderValue::from_static("Upgrade")
        );
    }

    #[test]
    fn test_insensitive_compare() {
        assert!(case_insensitive_eq("sessionid", SESSION_ID_QUERY));
        assert!(case_insensitive_eq("SESSIONID", SESSION_ID_QUERY));
        assert!(!case_insensitive_eq("session", SESSION_ID_QUERY));
        assert!(!case_insensitive_eq("sessionIds", SESSION_ID_QUERY));
    }

    #[test]
    fn test_calculate_sec_websocket_accept() {
        let mut buffer = [0u8; 32];
        let accept = calculate_sec_websocket_accept(SAMPLE_KEY.as_bytes(), &mut buffer);
        assert_eq!(accept, SAMPLE_ACCEPT.as_bytes());
    }

    #[test]
    fn test_build_uri() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path_and_query = "/websocket".parse().unwrap();
        let path = Some(&path_and_query);
        let result = build_uri(socket_addr, path).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000/websocket");
    }

    #[test]
    fn test_build_uri_no_path() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let result = build_uri(socket_addr, None).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000");
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request() {
        let websocket_config = leak!(WebSocketConfig::default());
        let (client, server) = start_dummy_server().await;

        let req = Request::builder()
            .uri(session_uri(client.socket_addr))
            .header(UPGRADE, "websocket")
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = init_upgrade(req, client, shutdown, websocket_config)
            .await
            .unwrap();

        server.abort();

        assert_switching_protocols(&result);
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request_no_sec_key() {
        let websocket_config = leak!(WebSocketConfig::default());
        let (client, server) = start_dummy_server().await;

        let req = Request::builder()
            .uri(session_uri(client.socket_addr))
            .header(UPGRADE, "websocket")
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = init_upgrade(req, client, shutdown, websocket_config).await;

        server.abort();

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_attempt_upgrade_no_upgrade_header() {
        let websocket_config = leak!(WebSocketConfig::default());
        let (client, server) = start_dummy_server().await;

        let uri = Uri::builder()
            .scheme("http")
            .authority(client.socket_addr.to_string().as_str())
            .path_and_query("/")
            .build()
            .unwrap();

        let req = Request::builder()
            .uri(uri)
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = attempt_upgrade(req, client, shutdown, websocket_config)
            .await
            .unwrap();

        server.abort();

        match result {
            UpgradeStatus::NotUpgraded(_) => {}
            _ => panic!("Expected NotUpgraded"),
        }
    }

    #[tokio::test]
    async fn test_attempt_upgrade_with_upgrade_header() {
        let websocket_config = leak!(WebSocketConfig::default());
        let (client, server) = start_dummy_server().await;

        let req = Request::builder()
            .uri(session_uri(client.socket_addr))
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .header(UPGRADE, "websocket")
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = attempt_upgrade(req, client, shutdown, websocket_config)
            .await
            .unwrap();

        server.abort();

        match result {
            UpgradeStatus::Upgraded(res) => assert_switching_protocols(&res),
            _ => panic!("Expected Upgraded"),
        }
    }

    /// Echo WebSocket server that forwards every text or binary frame back.
    mod dummy_websocket_server {
        use std::{io::Error, net::SocketAddr};

        use futures_util::{future, StreamExt, TryStreamExt};
        use log::info;
        use tokio::net::{TcpListener, TcpStream};

        pub async fn run(addr: SocketAddr) -> Result<(), Error> {
            let try_socket = TcpListener::bind(&addr).await;
            let listener = try_socket.expect("Failed to bind");
            info!("Listening on: {addr}");

            while let Ok((stream, _)) = listener.accept().await {
                tokio::spawn(accept_connection(stream));
            }

            Ok(())
        }

        async fn accept_connection(stream: TcpStream) {
            let addr = stream
                .peer_addr()
                .expect("connected streams should have a peer address");
            info!("Peer address: {addr}");

            let ws_stream = tokio_tungstenite::accept_async(stream)
                .await
                .expect("Error during the websocket handshake occurred");

            info!("New WebSocket connection: {addr}");

            let (write, read) = ws_stream.split();
            read.try_filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
                .forward(write)
                .await
                .expect("Failed to forward messages")
        }
    }
}