1use std::{
6 borrow::Cow,
7 future::{ready, Future},
8 io::Error as IoError,
9 pin::Pin,
10 task::{Context, Poll},
11};
12
13use argan_core::BoxedError;
14use base64::prelude::*;
15use fastwebsockets::{
16 FragmentCollector, Frame, OpCode, Payload, Role, WebSocket as FastWebSocket,
17 WebSocketError as FastWebSocketError,
18};
19use futures_util::FutureExt;
20use http::{
21 header::{
22 ToStrError, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL,
23 SEC_WEBSOCKET_VERSION, UPGRADE,
24 },
25 HeaderValue, Method,
26};
27use hyper::upgrade::{OnUpgrade, Upgraded};
28use hyper_util::rt::TokioIo;
29use sha1::{Digest, Sha1};
30
31use crate::common::header_utils::split_header_value;
32
33use super::*;
34
/// Default maximum message size (16 MiB) handed to fastwebsockets' `set_max_message_size`.
///
/// New [`WebSocketUpgrade`]s start with this value; it can be changed with
/// [`WebSocketUpgrade::set_message_size_limit`].
const MESSAGE_SIZE_LIMIT: usize = 16 * 1024 * 1024;
39
/// A pending WebSocket upgrade extracted from a request that passed the opening handshake.
///
/// Configure the connection with the builder-style methods, then call
/// [`upgrade`](WebSocketUpgrade::upgrade) to obtain the `101 Switching Protocols` response.
pub struct WebSocketUpgrade {
    /// Prepared `101 Switching Protocols` response returned by `upgrade`.
    response: Response,
    /// Resolves to the upgraded connection once hyper hands it over.
    upgrade_future: UpgradeFuture,
    /// The client's raw `Sec-WebSocket-Protocol` header value, if it sent one.
    some_requested_protocols: Option<HeaderValue>,
    /// Protocol chosen via `select_protocol`; echoed back in the response.
    some_selected_protocol: Option<HeaderValue>,
    /// Forwarded to fastwebsockets' `set_max_message_size`.
    message_size_limit: usize,
    /// Forwarded to fastwebsockets' `set_auto_apply_mask`.
    auto_unmasking: bool,
    /// Forwarded to fastwebsockets' `set_auto_pong`.
    auto_sending_pong: bool,
    /// Forwarded to fastwebsockets' `set_auto_close`.
    auto_closing: bool,
}
54
55impl WebSocketUpgrade {
56 fn new(
57 response: Response,
58 upgrade_future: UpgradeFuture,
59 some_requested_protocols: Option<HeaderValue>,
60 ) -> Self {
61 Self {
62 response,
63 upgrade_future,
64 some_requested_protocols,
65 some_selected_protocol: None,
66 message_size_limit: MESSAGE_SIZE_LIMIT,
67 auto_unmasking: true,
68 auto_sending_pong: true,
69 auto_closing: false,
70 }
71 }
72
73 pub fn select_protocol<Func>(
76 &mut self,
77 selector: Func,
78 ) -> Result<Option<Cow<str>>, WebSocketUpgradeError>
79 where
80 Func: Fn(&str) -> bool,
81 {
82 if let Some(requested_protocols) = self.some_requested_protocols.as_ref() {
83 let header_values = split_header_value(requested_protocols)
84 .map_err(WebSocketUpgradeError::InvalidSecWebSocketProtocol)?;
85
86 for header_value_str in header_values {
87 if selector(header_value_str) {
88 let header_value = HeaderValue::from_str(header_value_str)
89 .expect("protocol header value should be taken from a valid header value");
90
91 self.some_selected_protocol = Some(header_value);
92
93 return Ok(Some(header_value_str.into()));
94 }
95 }
96 }
97
98 Ok(None)
99 }
100
101 pub fn set_message_size_limit(&mut self, size_limit: usize) -> &mut Self {
103 self.message_size_limit = size_limit;
104
105 self
106 }
107
108 pub fn turn_off_auto_unmasking(&mut self) -> &mut Self {
110 self.auto_unmasking = false;
111
112 self
113 }
114
115 pub fn turn_off_auto_sending_pong(&mut self) -> &mut Self {
117 self.auto_sending_pong = false;
118
119 self
120 }
121
122 pub fn turn_on_auto_closing(&mut self) -> &mut Self {
124 self.auto_closing = true;
125
126 self
127 }
128
129 pub fn upgrade<Func, Fut>(self, handle_upgrade_result: Func) -> Response
132 where
133 Func: FnOnce(Result<WebSocket, WebSocketUpgradeError>) -> Fut + Send + 'static,
134 Fut: Future<Output = ()>,
135 {
136 let Self {
137 mut response,
138 upgrade_future,
139 some_requested_protocols: _,
140 some_selected_protocol,
141 message_size_limit,
142 auto_unmasking,
143 auto_sending_pong,
144 auto_closing,
145 } = self;
146
147 tokio::spawn(async move {
148 let result = upgrade_future.await.map(|mut fws| {
149 fws.set_max_message_size(message_size_limit);
150 fws.set_auto_apply_mask(auto_unmasking);
151 fws.set_auto_pong(auto_sending_pong);
152 fws.set_auto_close(auto_closing);
153
154 WebSocket(FragmentCollector::new(fws))
155 });
156
157 handle_upgrade_result(result);
158 });
159
160 if let Some(selected_protocol) = some_selected_protocol {
161 response
162 .headers_mut()
163 .insert(SEC_WEBSOCKET_PROTOCOL, selected_protocol);
164 }
165
166 response
167 }
168}
169
170impl<B> FromRequest<B> for WebSocketUpgrade {
171 type Error = WebSocketUpgradeError;
172
173 fn from_request(
174 head_parts: &mut RequestHeadParts,
175 _: B,
176 ) -> impl Future<Output = Result<Self, Self::Error>> {
177 ready(websocket_handshake(head_parts))
178 }
179}
180
181pub(crate) fn websocket_handshake(
182 head: &mut RequestHeadParts,
183) -> Result<WebSocketUpgrade, WebSocketUpgradeError> {
184 if head.method != Method::GET {
185 panic!("WebSocket is not supported with methods other than GET")
186 }
187
188 if !head
189 .headers
190 .get(CONNECTION)
191 .is_some_and(|header_value| header_value.as_bytes().eq_ignore_ascii_case(b"upgrade"))
192 {
193 return Err(WebSocketUpgradeError::InvalidConnectionHeader);
194 }
195
196 if !head
197 .headers
198 .get(UPGRADE)
199 .is_some_and(|header_value| header_value.as_bytes().eq_ignore_ascii_case(b"websocket"))
200 {
201 return Err(WebSocketUpgradeError::InvalidUpgradeHeader);
202 }
203
204 if !head
205 .headers
206 .get(SEC_WEBSOCKET_VERSION)
207 .is_some_and(|header_value| header_value.as_bytes() == b"13")
208 {
209 return Err(WebSocketUpgradeError::InvalidSecWebSocketVersion);
210 }
211
212 let Some(sec_websocket_accept_value) = head
213 .headers
214 .get(SEC_WEBSOCKET_KEY)
215 .map(|header_value| sec_websocket_accept_value_from(header_value.as_bytes()))
216 else {
217 return Err(WebSocketUpgradeError::MissingSecWebSocketKey);
218 };
219
220 let Some(upgrade_future) = head.extensions.remove::<OnUpgrade>().map(UpgradeFuture) else {
221 return Err(WebSocketUpgradeError::UnupgradableConnection);
222 };
223
224 let some_requested_protocols = head.headers.get(SEC_WEBSOCKET_PROTOCOL);
225
226 let mut response = StatusCode::SWITCHING_PROTOCOLS.into_response();
227
228 response
229 .headers_mut()
230 .insert(CONNECTION, HeaderValue::from_static("upgrade"));
231
232 response
233 .headers_mut()
234 .insert(UPGRADE, HeaderValue::from_static("websocket"));
235
236 response
237 .headers_mut()
238 .insert(SEC_WEBSOCKET_ACCEPT, sec_websocket_accept_value);
239
240 Ok(WebSocketUpgrade::new(
241 response,
242 upgrade_future,
243 some_requested_protocols.cloned(),
244 ))
245}
246
247fn sec_websocket_accept_value_from(key: &[u8]) -> HeaderValue {
248 let mut sha1 = Sha1::new();
249 sha1.update(key);
250 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
251
252 let b64 = BASE64_STANDARD.encode(sha1.finalize());
253 HeaderValue::try_from(b64).expect("base64 encoded value must be a valid header value")
254}
255
256#[derive(Debug, crate::ImplError)]
261pub enum WebSocketUpgradeError {
262 #[error("invalid Connection header")]
264 InvalidConnectionHeader,
265 #[error("invalid Upgrade header")]
267 InvalidUpgradeHeader,
268 #[error("invalid Sec-WebSocket-Version")]
270 InvalidSecWebSocketVersion,
271 #[error("missing Sec-WebSocket-Key")]
273 MissingSecWebSocketKey,
274 #[error("invlaid Sec-WebSocket-Protocol")]
276 InvalidSecWebSocketProtocol(ToStrError),
277 #[error("unupgradable connection")]
279 UnupgradableConnection,
280 #[error(transparent)]
282 Failure(#[from] hyper::Error),
283}
284
285impl IntoResponse for WebSocketUpgradeError {
286 fn into_response(self) -> Response {
287 use WebSocketUpgradeError::*;
288
289 match self {
290 InvalidConnectionHeader
291 | InvalidUpgradeHeader
292 | InvalidSecWebSocketVersion
293 | MissingSecWebSocketKey
294 | InvalidSecWebSocketProtocol(_) => StatusCode::BAD_REQUEST.into_response(),
295 _ => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
296 }
297 }
298}
299
300struct UpgradeFuture(OnUpgrade);
304
305impl Future for UpgradeFuture {
306 type Output = Result<FastWebSocket<TokioIo<Upgraded>>, WebSocketUpgradeError>;
307
308 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309 match self.0.poll_unpin(cx) {
310 Poll::Ready(result) => Poll::Ready(
311 result
312 .map(|upgraded| FastWebSocket::after_handshake(TokioIo::new(upgraded), Role::Server))
313 .map_err(WebSocketUpgradeError::Failure),
314 ),
315 Poll::Pending => Poll::Pending,
316 }
317 }
318}
319
320pub struct WebSocket(FragmentCollector<TokioIo<Upgraded>>);
325
326impl WebSocket {
327 pub async fn receive(&mut self) -> Option<Result<Message, WebSocketError>> {
331 match self.0.read_frame().await {
332 Ok(complete_frame) => match complete_frame.opcode {
333 OpCode::Text => {
334 let text = String::from_utf8(complete_frame.payload.to_vec())
336 .expect("text payload should have been guaranteed to be a valid utf-8");
337
338 Some(Ok(Message::Text(text)))
339 }
340 OpCode::Binary => Some(Ok(Message::Binary(complete_frame.payload.to_vec()))),
341 OpCode::Ping => Some(Ok(Message::Binary(complete_frame.payload.to_vec()))),
342 OpCode::Pong => Some(Ok(Message::Binary(complete_frame.payload.to_vec()))),
343 OpCode::Close => Some(Ok(Message::Close(None))),
344 OpCode::Continuation => Some(Err(WebSocketError::Unexpected(IncompleteMessage.into()))),
345 },
346 Err(error) => {
347 if let FastWebSocketError::ConnectionClosed = error {
348 return None;
349 }
350
351 Some(Err(error.into()))
352 }
353 }
354 }
355
356 pub async fn send(&mut self, message: Message) -> Result<(), WebSocketError> {
358 match message {
359 Message::Text(text) => {
360 let frame = Frame::text(Payload::Owned(text.into()));
361
362 self.0.write_frame(frame).await?
363 }
364 Message::Binary(binary) => {
365 let frame = Frame::binary(Payload::Owned(binary));
366
367 self.0.write_frame(frame).await?
368 }
369 Message::Ping(ping) => {
370 let frame = Frame::new(true, OpCode::Ping, None, Payload::Owned(ping));
371
372 self.0.write_frame(frame).await?
373 }
374 Message::Pong(pong) => {
375 let frame = Frame::pong(Payload::Owned(pong));
376
377 self.0.write_frame(frame).await?
378 }
379 Message::Close(some_close_frame) => {
380 let frame = if let Some(CloseFrame { code, reason }) = some_close_frame {
381 Frame::close(code.into(), reason.as_bytes())
382 } else {
383 Frame::close(CloseCode::_1000_Normal.into(), b"")
384 };
385
386 self.0.write_frame(frame).await?
387 }
388 };
389 Ok(())
390 }
391
392 #[inline(always)]
394 pub async fn close(mut self) -> Result<(), WebSocketError> {
395 self.send(Message::Close(None)).await
396 }
397}
398
/// A complete WebSocket message exchanged through [`WebSocket`].
pub enum Message {
    /// A UTF-8 text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping control frame's payload.
    Ping(Vec<u8>),
    /// A pong control frame's payload.
    Pong(Vec<u8>),
    /// A close control frame.
    ///
    /// When sent with `None`, code 1000 (normal) and an empty reason are used.
    Close(Option<CloseFrame>),
}
410
411pub struct CloseFrame {
415 code: CloseCode,
416 reason: Cow<'static, str>,
417}
418
// Variant names embed the numeric status code (e.g. `_1000_Normal`) so the wire value is
// visible at a glance, which requires leading underscores and digits.
#[allow(non_camel_case_types)]
/// A WebSocket close status code (RFC 6455, section 7.4).
///
/// Every `u16` converts losslessly to and from a `CloseCode`.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum CloseCode {
    /// 1000: normal closure.
    _1000_Normal,

    /// 1001: the endpoint is going away (e.g. server shutdown or page navigation).
    _1001_GoingAway,

    /// 1002: the endpoint is terminating due to a protocol error.
    _1002_ProtocolError,

    /// 1003: the endpoint received a type of data it cannot accept.
    _1003_UnsupportedData,

    /// 1005: no status code was present. Must not be sent in a close frame.
    _1005_NoStatusReceived,

    /// 1006: the connection closed abnormally, without a close frame. Must not be sent.
    _1006_Abnormal,

    /// 1007: a message contained data inconsistent with its type (e.g. non-UTF-8 text).
    _1007_InvalidPayloadData,

    /// 1008: a message violated the endpoint's policy.
    _1008_PolicyViolation,

    /// 1009: a message was too big to process.
    _1009_MessageTooBig,

    /// 1010: the client expected the server to negotiate one or more extensions.
    _1010_MandatoryExtension,

    /// 1011: the server encountered an unexpected condition.
    _1011_InternalError,

    /// 1012: the service is restarting.
    _1012_ServerRestart,

    /// 1013: try again later.
    _1013_TryLater,

    /// 1014: the server, acting as a gateway, received an invalid upstream response.
    _1014_BadGateway,

    /// 1015: the TLS handshake failed. Must not be sent in a close frame.
    _1015_TlsError,

    /// 1–999: not used.
    #[doc(hidden)]
    Unused(u16),
    /// 1004 and 1016–2999: reserved for the protocol.
    #[doc(hidden)]
    Reserved(u16),
    /// 3000–3999: registered with IANA.
    #[doc(hidden)]
    IanaRegistered(u16),
    /// 4000–4999: available for private use.
    #[doc(hidden)]
    Private(u16),
    /// 0 and 5000 and above: outside every defined range.
    #[doc(hidden)]
    Bad(u16),
}

impl From<u16> for CloseCode {
    fn from(code: u16) -> CloseCode {
        use CloseCode::*;

        match code {
            1000 => _1000_Normal,
            1001 => _1001_GoingAway,
            1002 => _1002_ProtocolError,
            1003 => _1003_UnsupportedData,
            1005 => _1005_NoStatusReceived,
            1006 => _1006_Abnormal,
            1007 => _1007_InvalidPayloadData,
            1008 => _1008_PolicyViolation,
            1009 => _1009_MessageTooBig,
            1010 => _1010_MandatoryExtension,
            1011 => _1011_InternalError,
            1012 => _1012_ServerRestart,
            1013 => _1013_TryLater,
            1014 => _1014_BadGateway,
            1015 => _1015_TlsError,
            1..=999 => Unused(code),
            // 1004 is reserved by RFC 6455 and belongs with the other reserved codes,
            // not with the out-of-range ones.
            1004 | 1016..=2999 => Reserved(code),
            3000..=3999 => IanaRegistered(code),
            4000..=4999 => Private(code),
            _ => Bad(code),
        }
    }
}

impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> u16 {
        use CloseCode::*;

        match code {
            _1000_Normal => 1000,
            _1001_GoingAway => 1001,
            _1002_ProtocolError => 1002,
            _1003_UnsupportedData => 1003,
            _1005_NoStatusReceived => 1005,
            _1006_Abnormal => 1006,
            _1007_InvalidPayloadData => 1007,
            _1008_PolicyViolation => 1008,
            _1009_MessageTooBig => 1009,
            _1010_MandatoryExtension => 1010,
            _1011_InternalError => 1011,
            _1012_ServerRestart => 1012,
            _1013_TryLater => 1013,
            _1014_BadGateway => 1014,
            _1015_TlsError => 1015,
            Unused(code) | Reserved(code) | IanaRegistered(code) | Private(code) | Bad(code) => {
                code
            }
        }
    }
}
574
575#[non_exhaustive]
580#[derive(Debug, crate::ImplError)]
581pub enum WebSocketError {
582 #[error("invalid fragment")]
584 InvalidFragment,
585 #[error("invalid UTF-8")]
587 InvalidUTF8,
588 #[error("invalid continuation frame")]
590 InvalidContinuationFrame,
591 #[error("invalid close frame")]
593 InvalidCloseFrame,
594 #[error("invalid close code")]
596 InvalidCloseCode,
597 #[error("unexpected EOF")]
599 UnexpectedEOF,
600 #[error("non-zero reserved bits")]
602 NonZeroReservedBits,
603 #[error("fragmented control frame")]
605 FragmentedControlFrame,
606 #[error("ping frame too large")]
608 PingFrameTooLarge,
609 #[error("message too large ")]
611 MessageTooLarge,
612 #[error("Invalid value")]
614 InvalidValue,
615 #[error(transparent)]
616 Io(#[from] IoError),
618 #[error(transparent)]
619 Unexpected(BoxedError),
621}
622
623impl From<FastWebSocketError> for WebSocketError {
624 fn from(fast_web_socket_error: FastWebSocketError) -> Self {
625 match fast_web_socket_error {
626 FastWebSocketError::InvalidFragment => Self::InvalidFragment,
627 FastWebSocketError::InvalidUTF8 => Self::InvalidUTF8,
628 FastWebSocketError::InvalidContinuationFrame => Self::InvalidContinuationFrame,
629 FastWebSocketError::InvalidCloseFrame => Self::InvalidCloseFrame,
630 FastWebSocketError::InvalidCloseCode => Self::InvalidCloseCode,
631 FastWebSocketError::UnexpectedEOF => Self::UnexpectedEOF,
632 FastWebSocketError::ReservedBitsNotZero => Self::NonZeroReservedBits,
633 FastWebSocketError::ControlFrameFragmented => Self::FragmentedControlFrame,
634 FastWebSocketError::PingFrameTooLarge => Self::PingFrameTooLarge,
635 FastWebSocketError::FrameTooLarge => Self::MessageTooLarge,
636 FastWebSocketError::InvalidValue => Self::InvalidValue,
637 FastWebSocketError::IoError(io_error) => Self::Io(io_error),
638 unexpected_error => Self::Unexpected(unexpected_error.into()),
639 }
640 }
641}
642
/// Internal error, wrapped in `WebSocketError::Unexpected`.
///
/// It is reported when reading yields a bare continuation frame instead of a complete
/// message.
#[derive(Debug, crate::ImplError)]
#[error("incomplete message")]
struct IncompleteMessage;
646
647