faucet_server/client/
websockets.rs

1use super::{pool::ExtractSocketAddr, Client, ExclusiveBody};
2use crate::{
3    error::{FaucetError, FaucetResult},
4    global_conn::{add_connection, remove_connection},
5};
6use base64::Engine;
7use hyper::{
8    header::UPGRADE,
9    http::{uri::PathAndQuery, HeaderValue},
10    upgrade::Upgraded,
11    HeaderMap, Request, Response, StatusCode, Uri,
12};
13use hyper_util::rt::TokioIo;
14use sha1::{Digest, Sha1};
15use std::net::SocketAddr;
16
17struct UpgradeInfo {
18    headers: HeaderMap,
19    uri: Uri,
20}
21
22impl UpgradeInfo {
23    fn new<ReqBody>(req: &Request<ReqBody>, socket_addr: SocketAddr) -> FaucetResult<Self> {
24        let headers = req.headers().clone();
25        let uri = build_uri(socket_addr, req.uri().path_and_query())?;
26        Ok(Self { headers, uri })
27    }
28}
29
30const SEC_WEBSOCKET_APPEND: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
31const SEC_WEBSOCKET_KEY: &str = "Sec-WebSocket-Key";
32const SEC_WEBSOCKET_ACCEPT: &str = "Sec-WebSocket-Accept";
33
34fn calculate_sec_websocket_accept<'buffer>(key: &[u8], buffer: &'buffer mut [u8]) -> &'buffer [u8] {
35    let mut hasher = Sha1::new();
36    hasher.update(key);
37    hasher.update(SEC_WEBSOCKET_APPEND);
38    let len = base64::engine::general_purpose::STANDARD
39        .encode_slice(hasher.finalize(), buffer)
40        .expect("Should always write the internal buffer");
41    &buffer[..len]
42}
43
44fn build_uri(socket_addr: SocketAddr, path: Option<&PathAndQuery>) -> FaucetResult<Uri> {
45    let mut uri_builder = Uri::builder()
46        .scheme("ws")
47        .authority(socket_addr.to_string());
48    match path {
49        Some(path) => uri_builder = uri_builder.path_and_query(path.clone()),
50        None => uri_builder = uri_builder.path_and_query("/"),
51    }
52    Ok(uri_builder.build()?)
53}
54
55async fn server_upgraded_io(upgraded: Upgraded, mut upgrade_info: UpgradeInfo) -> FaucetResult<()> {
56    let mut upgraded = TokioIo::new(upgraded);
57    // Bridge a websocket connection to ws://localhost:3838/websocket
58    // Use tokio-tungstenite to do the websocket handshake
59    let mut request = Request::builder().uri(upgrade_info.uri).body(())?;
60    std::mem::swap(request.headers_mut(), &mut upgrade_info.headers);
61    let (mut ws_tx, _) = tokio_tungstenite::connect_async(request).await?;
62
63    // Bridge the websocket stream to the upgraded connection
64    tokio::io::copy_bidirectional(&mut upgraded, ws_tx.get_mut()).await?;
65
66    Ok(())
67}
68
/// Result of checking a request for a connection upgrade.
pub enum UpgradeStatus<ReqBody> {
    /// The request carried an `Upgrade` header; holds the response that
    /// answers the upgrade handshake.
    Upgraded(Response<ExclusiveBody>),
    /// The request carried no `Upgrade` header; the request is handed back
    /// so the caller can process it normally.
    NotUpgraded(Request<ReqBody>),
}
73
74async fn upgrade_connection_from_request<ReqBody>(
75    mut req: Request<ReqBody>,
76    client: impl ExtractSocketAddr,
77) -> FaucetResult<()> {
78    let upgrade_info = UpgradeInfo::new(&req, client.socket_addr())?;
79    let upgraded = hyper::upgrade::on(&mut req).await?;
80    server_upgraded_io(upgraded, upgrade_info).await?;
81    Ok(())
82}
83
84async fn init_upgrade<ReqBody: Send + Sync + 'static>(
85    req: Request<ReqBody>,
86    client: impl ExtractSocketAddr + Send + Sync + 'static,
87) -> FaucetResult<Response<ExclusiveBody>> {
88    let mut res = Response::new(ExclusiveBody::empty());
89    let sec_websocket_key = req
90        .headers()
91        .get(SEC_WEBSOCKET_KEY)
92        .cloned()
93        .ok_or(FaucetError::no_sec_web_socket_key())?;
94    tokio::task::spawn(async move {
95        add_connection();
96        if let Err(e) = upgrade_connection_from_request(req, client).await {
97            log::error!("upgrade error: {:?}", e);
98        }
99        remove_connection();
100    });
101    *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
102    res.headers_mut()
103        .insert(UPGRADE, HeaderValue::from_static("websocket"));
104    res.headers_mut().insert(
105        hyper::header::CONNECTION,
106        HeaderValue::from_static("Upgrade"),
107    );
108    let mut buffer = [0u8; 32];
109    res.headers_mut().insert(
110        SEC_WEBSOCKET_ACCEPT,
111        HeaderValue::from_bytes(calculate_sec_websocket_accept(
112            sec_websocket_key.as_bytes(),
113            &mut buffer,
114        ))?,
115    );
116    Ok(res)
117}
118
119#[inline(always)]
120async fn attempt_upgrade<ReqBody: Send + Sync + 'static>(
121    req: Request<ReqBody>,
122    client: impl ExtractSocketAddr + Send + Sync + 'static,
123) -> FaucetResult<UpgradeStatus<ReqBody>> {
124    if req.headers().contains_key(UPGRADE) {
125        return Ok(UpgradeStatus::Upgraded(init_upgrade(req, client).await?));
126    }
127    Ok(UpgradeStatus::NotUpgraded(req))
128}
129
130impl Client {
131    pub async fn attempt_upgrade<ReqBody>(
132        &self,
133        req: Request<ReqBody>,
134    ) -> FaucetResult<UpgradeStatus<ReqBody>>
135    where
136        ReqBody: Send + Sync + 'static,
137    {
138        attempt_upgrade(req, self.clone()).await
139    }
140}
141
#[cfg(test)]
mod tests {
    use crate::networking::get_available_socket;

    use super::*;

    /// Sample client key from the RFC 6455 handshake example.
    const SAMPLE_KEY: &str = "dGhlIHNhbXBsZSBub25jZQ==";
    /// Accept token that corresponds to [`SAMPLE_KEY`].
    const SAMPLE_ACCEPT: &str = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";

    /// Minimal `ExtractSocketAddr` implementation pointing at a fixed address.
    struct MockClient {
        socket_addr: SocketAddr,
    }

    impl ExtractSocketAddr for MockClient {
        fn socket_addr(&self) -> SocketAddr {
            self.socket_addr
        }
    }

    /// Picks a free socket address, starts the dummy WebSocket server on it
    /// and returns a client pointing at it together with the server task.
    async fn spawn_backend() -> (MockClient, tokio::task::JoinHandle<()>) {
        let socket_addr: SocketAddr = get_available_socket(20).await.unwrap();
        let server = tokio::spawn(async move {
            dummy_websocket_server::run(socket_addr).await.unwrap();
        });
        (MockClient { socket_addr }, server)
    }

    /// Builds an `http://<socket_addr>/` request with an empty body and the
    /// given headers.
    fn request_with_headers(socket_addr: SocketAddr, headers: &[(&str, &str)]) -> Request<()> {
        let uri = Uri::builder()
            .scheme("http")
            .authority(socket_addr.to_string().as_str())
            .path_and_query("/")
            .build()
            .unwrap();

        headers
            .iter()
            .fold(Request::builder().uri(uri), |builder, (name, value)| {
                builder.header(*name, *value)
            })
            .body(())
            .unwrap()
    }

    /// Asserts that `res` is the expected `101 Switching Protocols` answer to
    /// a handshake made with [`SAMPLE_KEY`].
    fn assert_handshake_response(res: &Response<ExclusiveBody>) {
        assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS);
        assert_eq!(
            res.headers().get(UPGRADE).unwrap(),
            HeaderValue::from_static("websocket")
        );
        assert_eq!(
            res.headers().get(SEC_WEBSOCKET_ACCEPT).unwrap(),
            HeaderValue::from_static(SAMPLE_ACCEPT)
        );
        assert_eq!(
            res.headers().get(hyper::header::CONNECTION).unwrap(),
            HeaderValue::from_static("Upgrade")
        );
    }

    #[test]
    fn test_calculate_sec_websocket_accept() {
        let mut buffer = [0u8; 32];
        let accept = calculate_sec_websocket_accept(SAMPLE_KEY.as_bytes(), &mut buffer);
        assert_eq!(accept, SAMPLE_ACCEPT.as_bytes());
    }

    #[test]
    fn test_build_uri() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path_and_query = "/websocket".parse().unwrap();
        let path = Some(&path_and_query);
        let result = build_uri(socket_addr, path).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000/websocket");
    }

    #[test]
    fn build_uri_no_path() {
        let socket_addr = "127.0.0.1:8000".parse().unwrap();
        let path = None;
        let result = build_uri(socket_addr, path).unwrap();
        assert_eq!(result, "ws://127.0.0.1:8000");
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request() {
        let (client, server) = spawn_backend().await;
        let req = request_with_headers(
            client.socket_addr(),
            &[("upgrade", "websocket"), (SEC_WEBSOCKET_KEY, SAMPLE_KEY)],
        );

        let result = init_upgrade(req, client).await.unwrap();

        server.abort();

        assert_handshake_response(&result);
    }

    #[tokio::test]
    async fn test_init_upgrade_from_request_no_sec_key() {
        let (client, server) = spawn_backend().await;
        let req = request_with_headers(client.socket_addr(), &[("upgrade", "websocket")]);

        let result = init_upgrade(req, client).await;

        server.abort();

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_attempt_upgrade_no_upgrade_header() {
        let (client, server) = spawn_backend().await;
        let req = request_with_headers(client.socket_addr(), &[(SEC_WEBSOCKET_KEY, SAMPLE_KEY)]);

        let result = attempt_upgrade(req, client).await.unwrap();

        server.abort();

        match result {
            UpgradeStatus::NotUpgraded(_) => {}
            UpgradeStatus::Upgraded(_) => panic!("Expected NotUpgraded"),
        }
    }

    #[tokio::test]
    async fn test_attempt_upgrade_with_upgrade_header() {
        let (client, server) = spawn_backend().await;
        let req = request_with_headers(
            client.socket_addr(),
            &[(SEC_WEBSOCKET_KEY, SAMPLE_KEY), ("upgrade", "websocket")],
        );

        let result = attempt_upgrade(req, client).await.unwrap();

        server.abort();

        match result {
            UpgradeStatus::Upgraded(res) => assert_handshake_response(&res),
            UpgradeStatus::NotUpgraded(_) => panic!("Expected Upgraded"),
        }
    }

    #[tokio::test]
    async fn test_upgrade_connection_from_request() {
        let (client, server) = spawn_backend().await;
        let req = request_with_headers(
            client.socket_addr(),
            &[("upgrade", "websocket"), (SEC_WEBSOCKET_KEY, SAMPLE_KEY)],
        );

        // NOTE(review): the join result is discarded, so a failed assertion
        // inside the task does not fail this test. A hand-built request
        // presumably carries no pending hyper upgrade, in which case
        // `upgrade_connection_from_request` would return an error — confirm
        // before tightening this to `.unwrap()`.
        let _ = tokio::spawn(async move {
            let result = upgrade_connection_from_request(req, client).await;
            assert!(result.is_ok());
        })
        .await;

        server.abort();
    }

    /// Small WebSocket echo server used as the upgrade target in the tests.
    mod dummy_websocket_server {
        use std::{io::Error, net::SocketAddr};

        use futures_util::{future, StreamExt, TryStreamExt};
        use log::info;
        use tokio::net::{TcpListener, TcpStream};

        /// Binds `addr` and serves each accepted connection on its own task
        /// until accepting fails.
        ///
        /// # Panics
        ///
        /// Panics if `addr` cannot be bound.
        pub async fn run(addr: SocketAddr) -> Result<(), Error> {
            let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
            info!("Listening on: {}", addr);

            while let Ok((stream, _)) = listener.accept().await {
                tokio::spawn(accept_connection(stream));
            }

            Ok(())
        }

        /// Performs the WebSocket handshake on `stream` and echoes every text
        /// or binary message back to the sender.
        async fn accept_connection(stream: TcpStream) {
            let addr = stream
                .peer_addr()
                .expect("connected streams should have a peer address");
            info!("Peer address: {}", addr);

            let ws_stream = tokio_tungstenite::accept_async(stream)
                .await
                .expect("Error during the websocket handshake occurred");

            info!("New WebSocket connection: {}", addr);

            let (write, read) = ws_stream.split();
            // We should not forward messages other than text or binary.
            read.try_filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
                .forward(write)
                .await
                .expect("Failed to forward messages")
        }
    }
}