faucet_server/client/
websockets.rs

1use super::{pool::ExtractSocketAddr, Client, ExclusiveBody};
2use crate::{
3    error::{BadRequestReason, FaucetError, FaucetResult},
4    global_conn::{add_connection, remove_connection},
5    server::logging::{EventLogData, FaucetTracingLevel},
6    shutdown::ShutdownSignal,
7    telemetry::send_log_event,
8};
9use base64::Engine;
10use bytes::Bytes;
11use futures_util::StreamExt;
12use hyper::{
13    header::UPGRADE,
14    http::{uri::PathAndQuery, HeaderValue},
15    upgrade::Upgraded,
16    HeaderMap, Request, Response, StatusCode, Uri,
17};
18use hyper_util::rt::TokioIo;
19use serde_json::json;
20use sha1::{Digest, Sha1};
21use std::{
22    collections::HashMap, future::Future, net::SocketAddr, str::FromStr, sync::LazyLock,
23    time::Duration,
24};
25use tokio::sync::Mutex;
26use tokio_tungstenite::tungstenite::{
27    protocol::{frame::coding::CloseCode, CloseFrame},
28    Message, Utf8Bytes,
29};
30use uuid::Uuid;
31
32struct UpgradeInfo {
33    headers: HeaderMap,
34    uri: Uri,
35}
36
37impl UpgradeInfo {
38    fn new<ReqBody>(req: &Request<ReqBody>, socket_addr: SocketAddr) -> FaucetResult<Self> {
39        let headers = req.headers().clone();
40        let uri = build_uri(socket_addr, req.uri().path_and_query())?;
41        Ok(Self { headers, uri })
42    }
43}
44
45const SEC_WEBSOCKET_APPEND: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
46const SEC_WEBSOCKET_KEY: &str = "Sec-WebSocket-Key";
47const SEC_WEBSOCKET_ACCEPT: &str = "Sec-WebSocket-Accept";
48
49fn calculate_sec_websocket_accept<'buffer>(key: &[u8], buffer: &'buffer mut [u8]) -> &'buffer [u8] {
50    let mut hasher = Sha1::new();
51    hasher.update(key);
52    hasher.update(SEC_WEBSOCKET_APPEND);
53    let len = base64::engine::general_purpose::STANDARD
54        .encode_slice(hasher.finalize(), buffer)
55        .expect("Should always write the internal buffer");
56    &buffer[..len]
57}
58
59fn build_uri(socket_addr: SocketAddr, path: Option<&PathAndQuery>) -> FaucetResult<Uri> {
60    let mut uri_builder = Uri::builder()
61        .scheme("ws")
62        .authority(socket_addr.to_string());
63    match path {
64        Some(path) => uri_builder = uri_builder.path_and_query(path.clone()),
65        None => uri_builder = uri_builder.path_and_query("/"),
66    }
67    Ok(uri_builder.build()?)
68}
69
// We keep the Shiny tx and rx in memory in case the upgraded (client) connection is dropped: if
// the user reconnects, we want to immediately re-establish the connection back to Shiny.
72use futures_util::SinkExt;
73
74type ConnectionPair = (
75    futures_util::stream::SplitSink<
76        tokio_tungstenite::WebSocketStream<
77            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
78        >,
79        tokio_tungstenite::tungstenite::Message,
80    >,
81    futures_util::stream::SplitStream<
82        tokio_tungstenite::WebSocketStream<
83            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
84        >,
85    >,
86);
87
88#[derive(Default)]
89struct ConnectionInstance {
90    purged: bool,
91    access_count: usize,
92    pair: Option<ConnectionPair>,
93}
94
95impl ConnectionInstance {
96    fn take(&mut self) -> ConnectionPair {
97        self.access_count += 1;
98        self.pair.take().unwrap()
99    }
100    fn put_back(&mut self, pair: ConnectionPair) {
101        self.access_count += 1;
102        self.pair = Some(pair);
103    }
104}
105
106struct ConnectionManagerInner {
107    map: HashMap<Uuid, ConnectionInstance>,
108    purge_count: usize,
109}
110
111struct ConnectionManager {
112    inner: Mutex<ConnectionManagerInner>,
113}
114
115impl ConnectionManager {
116    fn new() -> Self {
117        ConnectionManager {
118            inner: Mutex::new(ConnectionManagerInner {
119                map: HashMap::new(),
120                purge_count: 0,
121            }),
122        }
123    }
124    async fn initialize_if_not(
125        &self,
126        session_id: Uuid,
127        attempt: usize,
128        init: impl Future<Output = FaucetResult<ConnectionPair>>,
129    ) -> Option<FaucetResult<ConnectionPair>> {
130        {
131            let mut inner = self.inner.lock().await;
132            let entry = inner.map.entry(session_id).or_default();
133            if entry.access_count != 0 {
134                return None;
135            }
136            if entry.purged {
137                return Some(Err(FaucetError::WebSocketConnectionPurged));
138            }
139
140            if entry.access_count == 0 && attempt > 0 {
141                return Some(Err(FaucetError::WebSocketConnectionPurged));
142            }
143
144            entry.access_count += 1;
145        }
146        let connection_pair = match init.await {
147            Ok(connection_pair) => connection_pair,
148            Err(e) => return Some(Err(e)),
149        };
150        Some(Ok(connection_pair))
151    }
152    async fn attempt_take(&self, session_id: Uuid) -> FaucetResult<ConnectionPair> {
153        match self.inner.try_lock() {
154            Ok(mut inner) => {
155                let instance = inner.map.entry(session_id).or_default();
156
157                if instance.access_count % 2 == 0 {
158                    return Ok(instance.take());
159                }
160
161                Err(FaucetError::WebSocketConnectionInUse)
162            }
163            _ => Err(FaucetError::WebSocketConnectionInUse),
164        }
165    }
166    async fn put_pack(&self, session_id: Uuid, pair: ConnectionPair) {
167        let mut inner = self.inner.lock().await;
168        if let Some(instance) = inner.map.get_mut(&session_id) {
169            instance.put_back(pair);
170        }
171    }
172    async fn remove_session(&self, session_id: Uuid) {
173        let mut inner = self.inner.lock().await;
174        inner.map.remove(&session_id);
175        inner.purge_count += 1;
176        if let Some(instance) = inner.map.get_mut(&session_id) {
177            instance.purged = true;
178        }
179    }
180}
181
182// Note: This is a simplified cache for a single shiny connection using a static Mutex.
183// A more robust solution would use a session identifier to cache multiple connections.
184// We use a std::sync::Mutex as the lock is not held across .await points.
185static SHINY_CONNECTION_CACHE: LazyLock<ConnectionManager> = LazyLock::new(ConnectionManager::new);
186
187async fn connect_to_worker(
188    mut upgrade_info: UpgradeInfo,
189    session_id: Uuid,
190) -> FaucetResult<ConnectionPair> {
191    let mut request = Request::builder().uri(upgrade_info.uri).body(())?;
192    upgrade_info.headers.append(
193        "FAUCET_SESSION_ID",
194        HeaderValue::from_str(&session_id.to_string())
195            .expect("Unable to set Session ID as header. This is a bug. please report it!"),
196    );
197    *request.headers_mut() = upgrade_info.headers;
198    let (shiny_ws, _) = tokio_tungstenite::connect_async(request).await?;
199    send_log_event(EventLogData {
200        target: "faucet".into(),
201        event_id: session_id,
202        parent_event_id: None,
203        level: FaucetTracingLevel::Info,
204        event_type: "websocket_connection".into(),
205        message: "Established new WebSocket connection to shiny".to_string(),
206        body: None,
207    });
208    Ok(shiny_ws.split())
209}
210
211async fn connect_or_retrieve(
212    upgrade_info: UpgradeInfo,
213    session_id: Uuid,
214    attempt: usize,
215) -> FaucetResult<ConnectionPair> {
216    let init_pair = SHINY_CONNECTION_CACHE
217        .initialize_if_not(
218            session_id,
219            attempt,
220            connect_to_worker(upgrade_info, session_id),
221        )
222        .await;
223
224    match init_pair {
225        None => {
226            // This means that the connection has already been initialized
227            // in the past
228            match SHINY_CONNECTION_CACHE.attempt_take(session_id).await {
229                Ok(con) => {
230                    send_log_event(EventLogData {
231                        target: "faucet".into(),
232                        event_id: Uuid::new_v4(),
233                        parent_event_id: Some(session_id),
234                        event_type: "websocket_connection".into(),
235                        level: FaucetTracingLevel::Info,
236                        message: "Client successfully reconnected".to_string(),
237                        body: Some(json!({"attempts": attempt})),
238                    });
239                    Ok(con)
240                }
241                Err(e) => FaucetResult::Err(e),
242            }
243        }
244        Some(init_pair_res) => init_pair_res,
245    }
246}
247
248const RECHECK_TIME: Duration = Duration::from_secs(60);
249const PING_INTERVAL: Duration = Duration::from_secs(1);
250const PING_INTERVAL_TIMEOUT: Duration = Duration::from_secs(30);
251const PING_BYTES: Bytes = Bytes::from_static(b"Ping");
252
253async fn server_upgraded_io(
254    upgraded: Upgraded,
255    upgrade_info: UpgradeInfo,
256    session_id: Uuid,
257    attempt: usize,
258    shutdown: &'static ShutdownSignal,
259) -> FaucetResult<()> {
260    // Set up the WebSocket connection with the client.
261    let upgraded = TokioIo::new(upgraded);
262    let upgraded_ws = tokio_tungstenite::WebSocketStream::from_raw_socket(
263        upgraded,
264        tokio_tungstenite::tungstenite::protocol::Role::Server,
265        None,
266    )
267    .await;
268    let (mut upgraded_tx, mut upgraded_rx) = upgraded_ws.split();
269
270    // Attempt to retrieve a cached connection to Shiny.
271    let (mut shiny_tx, mut shiny_rx) =
272        match connect_or_retrieve(upgrade_info, session_id, attempt).await {
273            Ok(pair) => pair,
274            Err(e) => match e {
275                FaucetError::WebSocketConnectionPurged => {
276                    upgraded_tx
277                        .send(Message::Close(Some(CloseFrame {
278                            code: CloseCode::Normal,
279                            reason: Utf8Bytes::from_static(
280                                "Connection purged due to inactivity, update or error.",
281                            ),
282                        })))
283                        .await?;
284                    return Err(FaucetError::WebSocketConnectionPurged);
285                }
286                e => return Err(e),
287            },
288        };
289
290    // Manually pump messages in both directions.
291    // This allows us to regain ownership of the streams after a disconnect.
292    let client_to_shiny = async {
293        loop {
294            log::debug!("Waiting for message or ping timeout");
295            tokio::select! {
296                msg = upgraded_rx.next() => {
297                    log::debug!("Received msg: {msg:?}");
298                    match msg {
299                        Some(Ok(msg)) => {
300                            if shiny_tx.send(msg).await.is_err() {
301                                break; // Shiny connection closed
302                            }
303                        },
304                        _ => break
305                    }
306                },
307                _ = tokio::time::sleep(PING_INTERVAL_TIMEOUT) => {
308                    log::debug!("Ping timeout reached for session {session_id}");
309                    break;
310                }
311            }
312        }
313    };
314
315    let shiny_to_client = async {
316        loop {
317            let ping_future = async {
318                tokio::time::sleep(PING_INTERVAL).await;
319                upgraded_tx.send(Message::Ping(PING_BYTES)).await
320            };
321            tokio::select! {
322                msg = shiny_rx.next() => {
323                    match msg {
324                        Some(Ok(msg)) => {
325                            if upgraded_tx.send(msg).await.is_err() {
326                                break; // Client connection closed
327                            }
328                        },
329                        _ => break
330                    }
331                },
332                _ = ping_future => {}
333            }
334        }
335    };
336
337    // Wait for either the client or Shiny to disconnect.
338    tokio::select! {
339        _ = client_to_shiny => {
340            send_log_event(EventLogData {
341                target: "faucet".into(),
342                event_id: Uuid::new_v4(),
343                parent_event_id: Some(session_id),
344                event_type: "websocket_connection".into(),
345                level: FaucetTracingLevel::Info,
346                message: "Session ended by client.".to_string(),
347                body: None,
348            });
349            log::debug!("Client connection closed for session {session_id}.")
350        },
351        _ = shiny_to_client => {
352            // If this happens that means shiny ended the session, immediately
353            // remove the session from the cache
354            SHINY_CONNECTION_CACHE.remove_session(session_id).await;
355            send_log_event(EventLogData {
356                target: "faucet".into(),
357                event_id: Uuid::new_v4(),
358                parent_event_id: Some(session_id),
359                event_type: "websocket_connection".into(),
360                level: FaucetTracingLevel::Info,
361                message: "Shiny session ended by Shiny.".to_string(),
362                body: None,
363            });
364            log::debug!("Shiny connection closed for session {session_id}.");
365            return Ok(());
366        },
367        _ = shutdown.wait() => {
368            log::debug!("Received shutdown signal. Exiting websocket bridge.");
369            return Ok(());
370        }
371    };
372
373    // Getting here meant that the only possible way the session ended is if
374    // the client ended the connection
375
376    log::debug!("Client websocket connection to session {session_id} ended but the Shiny connection is still alive. Saving for reconnection.");
377    SHINY_CONNECTION_CACHE
378        .put_pack(session_id, (shiny_tx, shiny_rx))
379        .await;
380
381    // Schedule a check in 30 seconds. If the connection is not in use
382    tokio::select! {
383        _ = tokio::time::sleep(RECHECK_TIME) => {
384            let entry = SHINY_CONNECTION_CACHE.attempt_take(session_id).await;
385            match entry {
386                Err(_) => (),
387                Ok((shiny_tx, shiny_rx)) => {
388                    let mut ws = shiny_tx
389                        .reunite(shiny_rx)
390                        .expect("shiny_rx and tx always have the same origin.");
391                    //
392                    if ws
393                        .close(Some(CloseFrame {
394                            code: CloseCode::Abnormal,
395                            reason: Utf8Bytes::default(),
396                        }))
397                        .await
398                        .is_ok()
399                    {
400                        log::debug!("Closed reserved connection for session {session_id}");
401                    }
402                    SHINY_CONNECTION_CACHE.remove_session(session_id).await;
403                }
404            }
405        },
406        _ = shutdown.wait() => {
407            log::debug!("Shutdown signaled, not running websocket cleanup for session {session_id}");
408        }
409    }
410
411    Ok(())
412}
413
414pub enum UpgradeStatus<ReqBody> {
415    Upgraded(Response<ExclusiveBody>),
416    NotUpgraded(Request<ReqBody>),
417}
418
419const SESSION_ID_QUERY: &str = "sessionId";
420
421async fn upgrade_connection_from_request<ReqBody>(
422    mut req: Request<ReqBody>,
423    client: impl ExtractSocketAddr,
424    shutdown: &'static ShutdownSignal,
425) -> FaucetResult<()> {
426    // Extract sessionId query parameter
427    let query = req.uri().query().ok_or(FaucetError::BadRequest(
428        BadRequestReason::MissingQueryParam("sessionId"),
429    ))?;
430
431    let mut session_id: Option<uuid::Uuid> = None;
432    let mut attempt: Option<usize> = None;
433
434    url::form_urlencoded::parse(query.as_bytes()).for_each(|(key, value)| {
435        if key == SESSION_ID_QUERY {
436            session_id = uuid::Uuid::from_str(&value).ok();
437        } else if key == "attempt" {
438            attempt = value.parse::<usize>().ok();
439        }
440    });
441
442    let session_id = session_id.ok_or(FaucetError::BadRequest(
443        BadRequestReason::MissingQueryParam("sessionId"),
444    ))?;
445
446    let attempt = attempt.ok_or(FaucetError::BadRequest(
447        BadRequestReason::MissingQueryParam("attempt"),
448    ))?;
449
450    let upgrade_info = UpgradeInfo::new(&req, client.socket_addr())?;
451    let upgraded = hyper::upgrade::on(&mut req).await?;
452    server_upgraded_io(upgraded, upgrade_info, session_id, attempt, shutdown).await?;
453    Ok(())
454}
455
456async fn init_upgrade<ReqBody: Send + Sync + 'static>(
457    req: Request<ReqBody>,
458    client: impl ExtractSocketAddr + Send + Sync + 'static,
459    shutdown: &'static ShutdownSignal,
460) -> FaucetResult<Response<ExclusiveBody>> {
461    let mut res = Response::new(ExclusiveBody::empty());
462    let sec_websocket_key = req
463        .headers()
464        .get(SEC_WEBSOCKET_KEY)
465        .cloned()
466        .ok_or(FaucetError::no_sec_web_socket_key())?;
467    tokio::task::spawn(async move {
468        add_connection();
469        if let Err(e) = upgrade_connection_from_request(req, client, shutdown).await {
470            log::error!("upgrade error: {e:?}");
471        }
472        remove_connection();
473    });
474    *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
475    res.headers_mut()
476        .insert(UPGRADE, HeaderValue::from_static("websocket"));
477    res.headers_mut().insert(
478        hyper::header::CONNECTION,
479        HeaderValue::from_static("Upgrade"),
480    );
481    let mut buffer = [0u8; 32];
482    res.headers_mut().insert(
483        SEC_WEBSOCKET_ACCEPT,
484        HeaderValue::from_bytes(calculate_sec_websocket_accept(
485            sec_websocket_key.as_bytes(),
486            &mut buffer,
487        ))?,
488    );
489    Ok(res)
490}
491
492#[inline(always)]
493async fn attempt_upgrade<ReqBody: Send + Sync + 'static>(
494    req: Request<ReqBody>,
495    client: impl ExtractSocketAddr + Send + Sync + 'static,
496    shutdown: &'static ShutdownSignal,
497) -> FaucetResult<UpgradeStatus<ReqBody>> {
498    if req.headers().contains_key(UPGRADE) {
499        return Ok(UpgradeStatus::Upgraded(
500            init_upgrade(req, client, shutdown).await?,
501        ));
502    }
503    Ok(UpgradeStatus::NotUpgraded(req))
504}
505
506impl Client {
507    pub async fn attempt_upgrade<ReqBody>(
508        &self,
509        req: Request<ReqBody>,
510        shutdown: &'static ShutdownSignal,
511    ) -> FaucetResult<UpgradeStatus<ReqBody>>
512    where
513        ReqBody: Send + Sync + 'static,
514    {
515        attempt_upgrade(req, self.clone(), shutdown).await
516    }
517}
518
#[cfg(test)]
mod tests {
    use crate::{leak, networking::get_available_socket, shutdown::ShutdownSignal};

    use super::*;
    use uuid::Uuid;

    /// Sample handshake key and its expected accept value (RFC 6455, section 1.3).
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// Client stub that only reports the worker address.
    struct MockClient {
        socket_addr: SocketAddr,
    }

    impl ExtractSocketAddr for MockClient {
        fn socket_addr(&self) -> SocketAddr {
            self.socket_addr
        }
    }

    /// Reserves a free local port and starts the dummy echo server on it.
    ///
    /// Returns a client pointing at that port and the server task, which the
    /// test aborts when done.
    async fn start_dummy_server() -> (MockClient, tokio::task::JoinHandle<()>) {
        let socket_addr = get_available_socket(20).await.unwrap();
        let server = tokio::spawn(async move {
            dummy_websocket_server::run(socket_addr).await.unwrap();
        });
        (MockClient { socket_addr }, server)
    }

    /// Builds an `http://` URI for `socket_addr` with the given path and query.
    fn http_uri(socket_addr: SocketAddr, path_and_query: &str) -> Uri {
        Uri::builder()
            .scheme("http")
            .authority(socket_addr.to_string().as_str())
            .path_and_query(path_and_query)
            .build()
            .unwrap()
    }

    /// Path and query carrying a freshly generated session id.
    fn session_query() -> String {
        format!("/?{}={}", SESSION_ID_QUERY, Uuid::now_v7())
    }

    /// Asserts that `res` is a complete `101` handshake answering `SAMPLE_KEY`.
    fn assert_switching_protocols(res: &Response<ExclusiveBody>) {
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
        assert_eq!(
            res.headers().get(UPGRADE).unwrap(),
            HeaderValue::from_static("websocket")
        );
        assert_eq!(
            res.headers().get(SEC_WEBSOCKET_ACCEPT).unwrap(),
            HeaderValue::from_static(SAMPLE_ACCEPT)
        );
        assert_eq!(
            res.headers().get(hyper::header::CONNECTION).unwrap(),
            HeaderValue::from_static("Upgrade")
        );
    }

    #[test]
    fn test_calculate_sec_websocket_accept() {
        let mut buffer = [0u8; 32];
        let accept = calculate_sec_websocket_accept(SAMPLE_KEY.as_bytes(), &mut buffer);
        assert_eq!(accept, SAMPLE_ACCEPT.as_bytes());
    }

    #[test]
    fn test_build_uri() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path_and_query = "/websocket".parse().unwrap();
        let result = build_uri(socket_addr, Some(&path_and_query)).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000/websocket");
    }

    #[test]
    fn build_uri_no_path() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let result = build_uri(socket_addr, None).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000");
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request() {
        let (client, server) = start_dummy_server().await;

        let req = Request::builder()
            .uri(http_uri(client.socket_addr, &session_query()))
            .header(UPGRADE, "websocket")
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = init_upgrade(req, client, shutdown).await.unwrap();

        server.abort();

        assert_switching_protocols(&result);
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request_no_sec_key() {
        let (client, server) = start_dummy_server().await;

        let req = Request::builder()
            .uri(http_uri(client.socket_addr, &session_query()))
            .header(UPGRADE, "websocket")
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = init_upgrade(req, client, shutdown).await;

        server.abort();

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_attempt_upgrade_no_upgrade_header() {
        let (client, server) = start_dummy_server().await;

        let req = Request::builder()
            .uri(http_uri(client.socket_addr, "/"))
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = attempt_upgrade(req, client, shutdown).await.unwrap();

        server.abort();

        assert!(
            matches!(result, UpgradeStatus::NotUpgraded(_)),
            "Expected NotUpgraded"
        );
    }

    #[tokio::test]
    async fn test_attempt_upgrade_with_upgrade_header() {
        let (client, server) = start_dummy_server().await;

        let req = Request::builder()
            .uri(http_uri(client.socket_addr, &session_query()))
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .header(UPGRADE, "websocket")
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = attempt_upgrade(req, client, shutdown).await.unwrap();

        server.abort();

        match result {
            UpgradeStatus::Upgraded(res) => assert_switching_protocols(&res),
            _ => panic!("Expected Upgraded"),
        }
    }

    mod dummy_websocket_server {
        use std::{io::Error, net::SocketAddr};

        use futures_util::{future, StreamExt, TryStreamExt};
        use log::info;
        use tokio::net::{TcpListener, TcpStream};

        /// Accepts connections on `addr` forever, echoing text and binary frames.
        pub async fn run(addr: SocketAddr) -> Result<(), Error> {
            // Create the event loop and TCP listener we'll accept connections on.
            let try_socket = TcpListener::bind(&addr).await;
            let listener = try_socket.expect("Failed to bind");
            info!("Listening on: {addr}");

            while let Ok((stream, _)) = listener.accept().await {
                tokio::spawn(accept_connection(stream));
            }

            Ok(())
        }

        async fn accept_connection(stream: TcpStream) {
            let addr = stream
                .peer_addr()
                .expect("connected streams should have a peer address");
            info!("Peer address: {addr}");

            let ws_stream = tokio_tungstenite::accept_async(stream)
                .await
                .expect("Error during the websocket handshake occurred");

            info!("New WebSocket connection: {addr}");

            let (write, read) = ws_stream.split();
            // We should not forward messages other than text or binary.
            read.try_filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
                .forward(write)
                .await
                .expect("Failed to forward messages")
        }
    }
}