faucet_server/client/
websockets.rs

1use super::{pool::ExtractSocketAddr, Client, ExclusiveBody};
2use crate::{
3    error::{BadRequestReason, FaucetError, FaucetResult},
4    global_conn::{add_connection, remove_connection},
5    server::logging::{EventLogData, FaucetTracingLevel},
6    shutdown::ShutdownSignal,
7    telemetry::send_log_event,
8};
9use base64::Engine;
10use bytes::Bytes;
11use futures_util::StreamExt;
12use hyper::{
13    header::UPGRADE,
14    http::{uri::PathAndQuery, HeaderValue},
15    upgrade::Upgraded,
16    HeaderMap, Request, Response, StatusCode, Uri,
17};
18use hyper_util::rt::TokioIo;
19use serde_json::json;
20use sha1::{Digest, Sha1};
21use std::{
22    collections::HashMap, future::Future, net::SocketAddr, str::FromStr, sync::LazyLock,
23    time::Duration,
24};
25use tokio::sync::Mutex;
26use tokio_tungstenite::tungstenite::{
27    protocol::{frame::coding::CloseCode, CloseFrame, WebSocketConfig},
28    Message, Utf8Bytes,
29};
30use uuid::Uuid;
31
32struct UpgradeInfo {
33    headers: HeaderMap,
34    uri: Uri,
35}
36
37impl UpgradeInfo {
38    fn new<ReqBody>(req: &Request<ReqBody>, socket_addr: SocketAddr) -> FaucetResult<Self> {
39        let headers = req.headers().clone();
40        let uri = build_uri(socket_addr, req.uri().path_and_query())?;
41        Ok(Self { headers, uri })
42    }
43}
44
45const SEC_WEBSOCKET_APPEND: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
46const SEC_WEBSOCKET_KEY: &str = "Sec-WebSocket-Key";
47const SEC_WEBSOCKET_ACCEPT: &str = "Sec-WebSocket-Accept";
48
49fn calculate_sec_websocket_accept<'buffer>(key: &[u8], buffer: &'buffer mut [u8]) -> &'buffer [u8] {
50    let mut hasher = Sha1::new();
51    hasher.update(key);
52    hasher.update(SEC_WEBSOCKET_APPEND);
53    let len = base64::engine::general_purpose::STANDARD
54        .encode_slice(hasher.finalize(), buffer)
55        .expect("Should always write the internal buffer");
56    &buffer[..len]
57}
58
59fn build_uri(socket_addr: SocketAddr, path: Option<&PathAndQuery>) -> FaucetResult<Uri> {
60    let mut uri_builder = Uri::builder()
61        .scheme("ws")
62        .authority(socket_addr.to_string());
63    match path {
64        Some(path) => uri_builder = uri_builder.path_and_query(path.clone()),
65        None => uri_builder = uri_builder.path_and_query("/"),
66    }
67    Ok(uri_builder.build()?)
68}
69
// We keep the Shiny tx and rx halves in memory in case the upgraded (client) connection is dropped. If the user reconnects,
// we want to immediately re-establish the bridge to the same Shiny connection.
72use futures_util::SinkExt;
73
74type ConnectionPair = (
75    futures_util::stream::SplitSink<
76        tokio_tungstenite::WebSocketStream<
77            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
78        >,
79        tokio_tungstenite::tungstenite::Message,
80    >,
81    futures_util::stream::SplitStream<
82        tokio_tungstenite::WebSocketStream<
83            tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
84        >,
85    >,
86);
87
88#[derive(Default)]
89struct ConnectionInstance {
90    purged: bool,
91    access_count: usize,
92    pair: Option<ConnectionPair>,
93}
94
95impl ConnectionInstance {
96    fn take(&mut self) -> ConnectionPair {
97        self.access_count += 1;
98        self.pair.take().unwrap()
99    }
100    fn put_back(&mut self, pair: ConnectionPair) {
101        self.access_count += 1;
102        self.pair = Some(pair);
103    }
104}
105
106struct ConnectionManagerInner {
107    map: HashMap<Uuid, ConnectionInstance>,
108    purge_count: usize,
109}
110
111struct ConnectionManager {
112    inner: Mutex<ConnectionManagerInner>,
113}
114
115impl ConnectionManager {
116    fn new() -> Self {
117        ConnectionManager {
118            inner: Mutex::new(ConnectionManagerInner {
119                map: HashMap::new(),
120                purge_count: 0,
121            }),
122        }
123    }
124    async fn initialize_if_not(
125        &self,
126        session_id: Uuid,
127        attempt: usize,
128        init: impl Future<Output = FaucetResult<ConnectionPair>>,
129    ) -> Option<FaucetResult<ConnectionPair>> {
130        {
131            let mut inner = self.inner.lock().await;
132            let entry = inner.map.entry(session_id).or_default();
133            if entry.access_count != 0 {
134                return None;
135            }
136            if entry.purged {
137                return Some(Err(FaucetError::WebSocketConnectionPurged));
138            }
139
140            if entry.access_count == 0 && attempt > 0 {
141                return Some(Err(FaucetError::WebSocketConnectionPurged));
142            }
143
144            entry.access_count += 1;
145        }
146        let connection_pair = match init.await {
147            Ok(connection_pair) => connection_pair,
148            Err(e) => return Some(Err(e)),
149        };
150        Some(Ok(connection_pair))
151    }
152    async fn attempt_take(&self, session_id: Uuid) -> FaucetResult<ConnectionPair> {
153        match self.inner.try_lock() {
154            Ok(mut inner) => {
155                let instance = inner.map.entry(session_id).or_default();
156
157                if instance.access_count % 2 == 0 {
158                    return Ok(instance.take());
159                }
160
161                Err(FaucetError::WebSocketConnectionInUse)
162            }
163            _ => Err(FaucetError::WebSocketConnectionInUse),
164        }
165    }
166    async fn put_pack(&self, session_id: Uuid, pair: ConnectionPair) {
167        let mut inner = self.inner.lock().await;
168        if let Some(instance) = inner.map.get_mut(&session_id) {
169            instance.put_back(pair);
170        }
171    }
172    async fn remove_session(&self, session_id: Uuid) {
173        let mut inner = self.inner.lock().await;
174        inner.map.remove(&session_id);
175        inner.purge_count += 1;
176        if let Some(instance) = inner.map.get_mut(&session_id) {
177            instance.purged = true;
178        }
179    }
180}
181
182// Note: This is a simplified cache for a single shiny connection using a static Mutex.
183// A more robust solution would use a session identifier to cache multiple connections.
184// We use a std::sync::Mutex as the lock is not held across .await points.
185static SHINY_CONNECTION_CACHE: LazyLock<ConnectionManager> = LazyLock::new(ConnectionManager::new);
186
187async fn connect_to_worker(
188    mut upgrade_info: UpgradeInfo,
189    session_id: Uuid,
190    config: &'static WebSocketConfig,
191) -> FaucetResult<ConnectionPair> {
192    let mut request = Request::builder().uri(upgrade_info.uri).body(())?;
193    upgrade_info.headers.append(
194        "FAUCET_SESSION_ID",
195        HeaderValue::from_str(&session_id.to_string())
196            .expect("Unable to set Session ID as header. This is a bug. please report it!"),
197    );
198    *request.headers_mut() = upgrade_info.headers;
199    let (shiny_ws, _) =
200        tokio_tungstenite::connect_async_with_config(request, Some(*config), false).await?;
201    send_log_event(EventLogData {
202        target: "faucet".into(),
203        event_id: session_id,
204        parent_event_id: None,
205        level: FaucetTracingLevel::Info,
206        event_type: "websocket_connection".into(),
207        message: "Established new WebSocket connection to shiny".to_string(),
208        body: None,
209    });
210    Ok(shiny_ws.split())
211}
212
213async fn connect_or_retrieve(
214    upgrade_info: UpgradeInfo,
215    session_id: Uuid,
216    attempt: usize,
217    config: &'static WebSocketConfig,
218) -> FaucetResult<ConnectionPair> {
219    let init_pair = SHINY_CONNECTION_CACHE
220        .initialize_if_not(
221            session_id,
222            attempt,
223            connect_to_worker(upgrade_info, session_id, config),
224        )
225        .await;
226
227    match init_pair {
228        None => {
229            // This means that the connection has already been initialized
230            // in the past
231            match SHINY_CONNECTION_CACHE.attempt_take(session_id).await {
232                Ok(con) => {
233                    send_log_event(EventLogData {
234                        target: "faucet".into(),
235                        event_id: Uuid::new_v4(),
236                        parent_event_id: Some(session_id),
237                        event_type: "websocket_connection".into(),
238                        level: FaucetTracingLevel::Info,
239                        message: "Client successfully reconnected".to_string(),
240                        body: Some(json!({"attempts": attempt})),
241                    });
242                    Ok(con)
243                }
244                Err(e) => FaucetResult::Err(e),
245            }
246        }
247        Some(init_pair_res) => init_pair_res,
248    }
249}
250
251const RECHECK_TIME: Duration = Duration::from_secs(60);
252const PING_INTERVAL: Duration = Duration::from_secs(1);
253const PING_INTERVAL_TIMEOUT: Duration = Duration::from_secs(30);
254const PING_BYTES: Bytes = Bytes::from_static(b"Ping");
255
256async fn server_upgraded_io(
257    upgraded: Upgraded,
258    upgrade_info: UpgradeInfo,
259    session_id: Uuid,
260    attempt: usize,
261    shutdown: &'static ShutdownSignal,
262    websocket_config: &'static WebSocketConfig,
263) -> FaucetResult<()> {
264    // Set up the WebSocket connection with the client.
265    let upgraded = TokioIo::new(upgraded);
266    let upgraded_ws = tokio_tungstenite::WebSocketStream::from_raw_socket(
267        upgraded,
268        tokio_tungstenite::tungstenite::protocol::Role::Server,
269        Some(*websocket_config),
270    )
271    .await;
272    let (mut upgraded_tx, mut upgraded_rx) = upgraded_ws.split();
273
274    // Attempt to retrieve a cached connection to Shiny.
275    let (mut shiny_tx, mut shiny_rx) =
276        match connect_or_retrieve(upgrade_info, session_id, attempt, websocket_config).await {
277            Ok(pair) => pair,
278            Err(e) => match e {
279                FaucetError::WebSocketConnectionPurged => {
280                    upgraded_tx
281                        .send(Message::Close(Some(CloseFrame {
282                            code: CloseCode::Normal,
283                            reason: Utf8Bytes::from_static(
284                                "Connection purged due to inactivity, update or error.",
285                            ),
286                        })))
287                        .await?;
288                    return Err(FaucetError::WebSocketConnectionPurged);
289                }
290                e => return Err(e),
291            },
292        };
293
294    // Manually pump messages in both directions.
295    // This allows us to regain ownership of the streams after a disconnect.
296    let client_to_shiny = async {
297        loop {
298            log::debug!("Waiting for message or ping timeout");
299            tokio::select! {
300                msg = upgraded_rx.next() => {
301                    log::debug!("Received msg: {msg:?}");
302                    match msg {
303                        Some(Ok(msg)) => {
304                            if shiny_tx.send(msg).await.is_err() {
305                                break; // Shiny connection closed
306                            }
307                        },
308                        Some(Err(e)) => {
309                            log::error!("Error sending websocket message to shiny: {e}");
310                            break
311                        }
312                        _ => break
313                    }
314                },
315                _ = tokio::time::sleep(PING_INTERVAL_TIMEOUT) => {
316                    log::debug!("Ping timeout reached for session {session_id}");
317                    break;
318                }
319            }
320        }
321    };
322
323    let shiny_to_client = async {
324        loop {
325            let ping_future = async {
326                tokio::time::sleep(PING_INTERVAL).await;
327                upgraded_tx.send(Message::Ping(PING_BYTES)).await
328            };
329            tokio::select! {
330                msg = shiny_rx.next() => {
331                    match msg {
332                        Some(Ok(msg)) => {
333                            if upgraded_tx.send(msg).await.is_err() {
334                                break; // Client connection closed
335                            }
336                        },
337                        Some(Err(e)) => {
338                            log::error!("Error sending websocket message to client: {e}");
339                            break
340                        }
341                        _ => break
342                    }
343                },
344                _ = ping_future => {}
345            }
346        }
347    };
348
349    // Wait for either the client or Shiny to disconnect.
350    tokio::select! {
351        _ = client_to_shiny => {
352            send_log_event(EventLogData {
353                target: "faucet".into(),
354                event_id: Uuid::new_v4(),
355                parent_event_id: Some(session_id),
356                event_type: "websocket_connection".into(),
357                level: FaucetTracingLevel::Info,
358                message: "Session ended by client.".to_string(),
359                body: None,
360            });
361            log::debug!("Client connection closed for session {session_id}.")
362        },
363        _ = shiny_to_client => {
364            // If this happens that means shiny ended the session, immediately
365            // remove the session from the cache
366            SHINY_CONNECTION_CACHE.remove_session(session_id).await;
367            send_log_event(EventLogData {
368                target: "faucet".into(),
369                event_id: Uuid::new_v4(),
370                parent_event_id: Some(session_id),
371                event_type: "websocket_connection".into(),
372                level: FaucetTracingLevel::Info,
373                message: "Shiny session ended by Shiny.".to_string(),
374                body: None,
375            });
376            log::debug!("Shiny connection closed for session {session_id}.");
377            return Ok(());
378        },
379        _ = shutdown.wait() => {
380            log::debug!("Received shutdown signal. Exiting websocket bridge.");
381            return Ok(());
382        }
383    };
384
385    // Getting here meant that the only possible way the session ended is if
386    // the client ended the connection
387
388    log::debug!("Client websocket connection to session {session_id} ended but the Shiny connection is still alive. Saving for reconnection.");
389    SHINY_CONNECTION_CACHE
390        .put_pack(session_id, (shiny_tx, shiny_rx))
391        .await;
392
393    // Schedule a check in 30 seconds. If the connection is not in use
394    tokio::select! {
395        _ = tokio::time::sleep(RECHECK_TIME) => {
396            let entry = SHINY_CONNECTION_CACHE.attempt_take(session_id).await;
397            match entry {
398                Err(_) => (),
399                Ok((shiny_tx, shiny_rx)) => {
400                    let mut ws = shiny_tx
401                        .reunite(shiny_rx)
402                        .expect("shiny_rx and tx always have the same origin.");
403                    //
404                    if ws
405                        .close(Some(CloseFrame {
406                            code: CloseCode::Abnormal,
407                            reason: Utf8Bytes::default(),
408                        }))
409                        .await
410                        .is_ok()
411                    {
412                        log::debug!("Closed reserved connection for session {session_id}");
413                    }
414                    SHINY_CONNECTION_CACHE.remove_session(session_id).await;
415                }
416            }
417        },
418        _ = shutdown.wait() => {
419            log::debug!("Shutdown signaled, not running websocket cleanup for session {session_id}");
420        }
421    }
422
423    Ok(())
424}
425
426pub enum UpgradeStatus<ReqBody> {
427    Upgraded(Response<ExclusiveBody>),
428    NotUpgraded(Request<ReqBody>),
429}
430
/// Query parameter carrying the client's session id (matched case-insensitively).
const SESSION_ID_QUERY: &str = "sessionId";

/// Zero-allocation, ASCII case-insensitive string comparison.
///
/// Only ASCII letters are folded; non-ASCII bytes must match exactly.
/// Delegates to [`str::eq_ignore_ascii_case`], which has exactly these
/// semantics (length check plus byte-wise ASCII-lowercase compare).
fn case_insensitive_eq(this: &str, that: &str) -> bool {
    this.eq_ignore_ascii_case(that)
}
442
443async fn upgrade_connection_from_request<ReqBody>(
444    mut req: Request<ReqBody>,
445    client: impl ExtractSocketAddr,
446    shutdown: &'static ShutdownSignal,
447    websocket_config: &'static WebSocketConfig,
448) -> FaucetResult<()> {
449    // Extract sessionId query parameter
450    let query = req.uri().query().ok_or(FaucetError::BadRequest(
451        BadRequestReason::MissingQueryParam("Unable to parse query params"),
452    ))?;
453
454    let mut session_id: Option<uuid::Uuid> = None;
455    let mut attempt: Option<usize> = None;
456
457    url::form_urlencoded::parse(query.as_bytes()).for_each(|(key, value)| {
458        if case_insensitive_eq(&key, SESSION_ID_QUERY) {
459            session_id = uuid::Uuid::from_str(&value).ok();
460        } else if case_insensitive_eq(&key, "attempt") {
461            attempt = value.parse::<usize>().ok();
462        }
463    });
464
465    let session_id = session_id.ok_or(FaucetError::BadRequest(
466        BadRequestReason::MissingQueryParam("sessionId"),
467    ))?;
468
469    let attempt = attempt.ok_or(FaucetError::BadRequest(
470        BadRequestReason::MissingQueryParam("attempt"),
471    ))?;
472
473    let upgrade_info = UpgradeInfo::new(&req, client.socket_addr())?;
474    let upgraded = hyper::upgrade::on(&mut req).await?;
475    server_upgraded_io(
476        upgraded,
477        upgrade_info,
478        session_id,
479        attempt,
480        shutdown,
481        websocket_config,
482    )
483    .await?;
484    Ok(())
485}
486
487async fn init_upgrade<ReqBody: Send + Sync + 'static>(
488    req: Request<ReqBody>,
489    client: impl ExtractSocketAddr + Send + Sync + 'static,
490    shutdown: &'static ShutdownSignal,
491    websocket_config: &'static WebSocketConfig,
492) -> FaucetResult<Response<ExclusiveBody>> {
493    let mut res = Response::new(ExclusiveBody::empty());
494    let sec_websocket_key = req
495        .headers()
496        .get(SEC_WEBSOCKET_KEY)
497        .cloned()
498        .ok_or(FaucetError::no_sec_web_socket_key())?;
499    tokio::task::spawn(async move {
500        add_connection();
501        if let Err(e) =
502            upgrade_connection_from_request(req, client, shutdown, websocket_config).await
503        {
504            log::error!("upgrade error: {e:?}");
505        }
506        remove_connection();
507    });
508    *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
509    res.headers_mut()
510        .insert(UPGRADE, HeaderValue::from_static("websocket"));
511    res.headers_mut().insert(
512        hyper::header::CONNECTION,
513        HeaderValue::from_static("Upgrade"),
514    );
515    let mut buffer = [0u8; 32];
516    res.headers_mut().insert(
517        SEC_WEBSOCKET_ACCEPT,
518        HeaderValue::from_bytes(calculate_sec_websocket_accept(
519            sec_websocket_key.as_bytes(),
520            &mut buffer,
521        ))?,
522    );
523    Ok(res)
524}
525
526#[inline(always)]
527async fn attempt_upgrade<ReqBody: Send + Sync + 'static>(
528    req: Request<ReqBody>,
529    client: impl ExtractSocketAddr + Send + Sync + 'static,
530    shutdown: &'static ShutdownSignal,
531    websocket_config: &'static WebSocketConfig,
532) -> FaucetResult<UpgradeStatus<ReqBody>> {
533    if req.headers().contains_key(UPGRADE) {
534        return Ok(UpgradeStatus::Upgraded(
535            init_upgrade(req, client, shutdown, websocket_config).await?,
536        ));
537    }
538    Ok(UpgradeStatus::NotUpgraded(req))
539}
540
541impl Client {
542    pub async fn attempt_upgrade<ReqBody>(
543        &self,
544        req: Request<ReqBody>,
545        shutdown: &'static ShutdownSignal,
546        websocket_config: &'static WebSocketConfig,
547    ) -> FaucetResult<UpgradeStatus<ReqBody>>
548    where
549        ReqBody: Send + Sync + 'static,
550    {
551        attempt_upgrade(req, self.clone(), shutdown, websocket_config).await
552    }
553}
554
#[cfg(test)]
mod tests {
    use crate::{leak, networking::get_available_socket, shutdown::ShutdownSignal};

    use super::*;
    use uuid::Uuid;

    /// Example handshake key from RFC 6455 §1.3 and its expected accept value.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// `ExtractSocketAddr` stub that always reports a fixed address.
    struct MockClient {
        socket_addr: SocketAddr,
    }

    impl ExtractSocketAddr for MockClient {
        fn socket_addr(&self) -> SocketAddr {
            self.socket_addr
        }
    }

    /// Reserves a free local socket and starts the dummy WebSocket server on it.
    /// Returns a client pointing at it plus the server task (abort it when done).
    async fn spawn_dummy_server() -> (MockClient, tokio::task::JoinHandle<()>) {
        let socket_addr = get_available_socket(20).await.unwrap();
        let server = tokio::spawn(async move {
            dummy_websocket_server::run(socket_addr).await.unwrap();
        });
        (MockClient { socket_addr }, server)
    }

    /// Builds `http://<socket_addr>/?sessionId=<fresh UUIDv7>`.
    fn session_uri(socket_addr: SocketAddr) -> Uri {
        Uri::builder()
            .scheme("http")
            .authority(socket_addr.to_string().as_str())
            .path_and_query(format!("/?{}={}", SESSION_ID_QUERY, Uuid::now_v7()))
            .build()
            .unwrap()
    }

    /// Asserts `res` is the complete 101 handshake response for `SAMPLE_KEY`.
    fn assert_switching_protocols(res: &Response<ExclusiveBody>) {
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
        assert_eq!(
            res.headers().get(UPGRADE).unwrap(),
            HeaderValue::from_static("websocket")
        );
        assert_eq!(
            res.headers().get(SEC_WEBSOCKET_ACCEPT).unwrap(),
            HeaderValue::from_static(SAMPLE_ACCEPT)
        );
        assert_eq!(
            res.headers().get(hyper::header::CONNECTION).unwrap(),
            HeaderValue::from_static("Upgrade")
        );
    }

    #[test]
    fn test_insensitive_compare() {
        let session_id = "sessionid";
        assert!(case_insensitive_eq(session_id, SESSION_ID_QUERY));
    }

    #[test]
    fn test_insensitive_compare_mismatch() {
        assert!(!case_insensitive_eq("session", SESSION_ID_QUERY));
        assert!(!case_insensitive_eq("sessionIdx", SESSION_ID_QUERY));
        assert!(!case_insensitive_eq("sessionJd", SESSION_ID_QUERY));
    }

    #[test]
    fn test_calculate_sec_websocket_accept() {
        let mut buffer = [0u8; 32];
        let accept = calculate_sec_websocket_accept(SAMPLE_KEY.as_bytes(), &mut buffer);
        assert_eq!(accept, SAMPLE_ACCEPT.as_bytes());
    }

    #[test]
    fn test_build_uri() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path_and_query = "/websocket".parse().unwrap();
        let path = Some(&path_and_query);
        let result = build_uri(socket_addr, path).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000/websocket");
    }

    #[test]
    fn build_uri_no_path() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path = None;
        let result = build_uri(socket_addr, path).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000");
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request() {
        let websocket_config = leak!(WebSocketConfig::default());
        let (client, server) = spawn_dummy_server().await;

        let req = Request::builder()
            .uri(session_uri(client.socket_addr))
            .header(UPGRADE, "websocket")
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = init_upgrade(req, client, shutdown, websocket_config)
            .await
            .unwrap();

        server.abort();

        assert_switching_protocols(&result);
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request_no_sec_key() {
        let websocket_config = leak!(WebSocketConfig::default());
        let (client, server) = spawn_dummy_server().await;

        let req = Request::builder()
            .uri(session_uri(client.socket_addr))
            .header(UPGRADE, "websocket")
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = init_upgrade(req, client, shutdown, websocket_config).await;

        server.abort();

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_attempt_upgrade_no_upgrade_header() {
        let (client, server) = spawn_dummy_server().await;
        let websocket_config = leak!(WebSocketConfig::default());

        let uri = Uri::builder()
            .scheme("http")
            .authority(client.socket_addr.to_string().as_str())
            .path_and_query("/")
            .build()
            .unwrap();

        let req = Request::builder()
            .uri(uri)
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = attempt_upgrade(req, client, shutdown, websocket_config)
            .await
            .unwrap();

        server.abort();

        assert!(
            matches!(result, UpgradeStatus::NotUpgraded(_)),
            "Expected NotUpgraded"
        );
    }

    #[tokio::test]
    async fn test_attempt_upgrade_with_upgrade_header() {
        let websocket_config = leak!(WebSocketConfig::default());
        let (client, server) = spawn_dummy_server().await;

        let req = Request::builder()
            .uri(session_uri(client.socket_addr))
            .header(SEC_WEBSOCKET_KEY, SAMPLE_KEY)
            .header(UPGRADE, "websocket")
            .body(())
            .unwrap();

        let shutdown = leak!(ShutdownSignal::new());
        let result = attempt_upgrade(req, client, shutdown, websocket_config)
            .await
            .unwrap();

        server.abort();

        match result {
            UpgradeStatus::Upgraded(res) => assert_switching_protocols(&res),
            UpgradeStatus::NotUpgraded(_) => panic!("Expected Upgraded"),
        }
    }

    mod dummy_websocket_server {
        use std::{io::Error, net::SocketAddr};

        use futures_util::{future, StreamExt, TryStreamExt};
        use log::info;
        use tokio::net::{TcpListener, TcpStream};

        /// Accepts connections on `addr` forever, echoing text/binary frames.
        pub async fn run(addr: SocketAddr) -> Result<(), Error> {
            // Create the event loop and TCP listener we'll accept connections on.
            let try_socket = TcpListener::bind(&addr).await;
            let listener = try_socket.expect("Failed to bind");
            info!("Listening on: {addr}");

            while let Ok((stream, _)) = listener.accept().await {
                tokio::spawn(accept_connection(stream));
            }

            Ok(())
        }

        async fn accept_connection(stream: TcpStream) {
            let addr = stream
                .peer_addr()
                .expect("connected streams should have a peer address");
            info!("Peer address: {addr}");

            let ws_stream = tokio_tungstenite::accept_async(stream)
                .await
                .expect("Error during the websocket handshake occurred");

            info!("New WebSocket connection: {addr}");

            let (write, read) = ws_stream.split();
            // We should not forward messages other than text or binary.
            read.try_filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
                .forward(write)
                .await
                .expect("Failed to forward messages")
        }
    }
}