argan/request/
websocket.rs

1//! Types to handle WebSocket connections.
2
3// ----------
4
5use std::{
6	borrow::Cow,
7	future::{ready, Future},
8	io::Error as IoError,
9	pin::Pin,
10	task::{Context, Poll},
11};
12
13use argan_core::BoxedError;
14use base64::prelude::*;
15use fastwebsockets::{
16	FragmentCollector, Frame, OpCode, Payload, Role, WebSocket as FastWebSocket,
17	WebSocketError as FastWebSocketError,
18};
19use futures_util::FutureExt;
20use http::{
21	header::{
22		ToStrError, CONNECTION, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY, SEC_WEBSOCKET_PROTOCOL,
23		SEC_WEBSOCKET_VERSION, UPGRADE,
24	},
25	HeaderValue, Method,
26};
27use hyper::upgrade::{OnUpgrade, Upgraded};
28use hyper_util::rt::TokioIo;
29use sha1::{Digest, Sha1};
30
31use crate::common::header_utils::split_header_value;
32
33use super::*;
34
35// --------------------------------------------------------------------------------
36// --------------------------------------------------------------------------------
37
/// Default maximum message size (16 MiB) given to every new `WebSocketUpgrade`;
/// it can be changed with `WebSocketUpgrade::set_message_size_limit`.
const MESSAGE_SIZE_LIMIT: usize = 16 << 20;
39
40// --------------------------------------------------
41// WebSocketUpgrade
42
/// An extractor to establish a WebSocket connection.
///
/// It is produced by a successful handshake validation and holds the prepared
/// `101 Switching Protocols` response together with the pending connection upgrade
/// until [`upgrade`](Self::upgrade) is called.
pub struct WebSocketUpgrade {
	/// Handshake response (status and upgrade headers) to send back to the client.
	response: Response,
	/// Pending upgrade (wraps hyper's `OnUpgrade`) that yields the raw WebSocket.
	upgrade_future: UpgradeFuture,
	/// Raw `Sec-WebSocket-Protocol` request header, if the client sent one.
	some_requested_protocols: Option<HeaderValue>,
	/// Protocol chosen with `select_protocol`, echoed back in the response.
	some_selected_protocol: Option<HeaderValue>,
	/// Maximum message size applied to the WebSocket on upgrade.
	message_size_limit: usize,
	/// Whether received frames are unmasked automatically (on by default).
	auto_unmasking: bool,
	/// Whether *pong* messages are sent automatically (on by default).
	auto_sending_pong: bool,
	/// Whether *close* messages are responded to automatically (off by default).
	auto_closing: bool,
}
54
55impl WebSocketUpgrade {
56	fn new(
57		response: Response,
58		upgrade_future: UpgradeFuture,
59		some_requested_protocols: Option<HeaderValue>,
60	) -> Self {
61		Self {
62			response,
63			upgrade_future,
64			some_requested_protocols,
65			some_selected_protocol: None,
66			message_size_limit: MESSAGE_SIZE_LIMIT,
67			auto_unmasking: true,
68			auto_sending_pong: true,
69			auto_closing: false,
70		}
71	}
72
73	/// Calls the given function for each listed protocol in the `Sec-WebSocket-Protocol`
74	/// header and selects the one the given function returned true for.
75	pub fn select_protocol<Func>(
76		&mut self,
77		selector: Func,
78	) -> Result<Option<Cow<str>>, WebSocketUpgradeError>
79	where
80		Func: Fn(&str) -> bool,
81	{
82		if let Some(requested_protocols) = self.some_requested_protocols.as_ref() {
83			let header_values = split_header_value(requested_protocols)
84				.map_err(WebSocketUpgradeError::InvalidSecWebSocketProtocol)?;
85
86			for header_value_str in header_values {
87				if selector(header_value_str) {
88					let header_value = HeaderValue::from_str(header_value_str)
89						.expect("protocol header value should be taken from a valid header value");
90
91					self.some_selected_protocol = Some(header_value);
92
93					return Ok(Some(header_value_str.into()));
94				}
95			}
96		}
97
98		Ok(None)
99	}
100
101	/// Sets the maximum size limit for the message.
102	pub fn set_message_size_limit(&mut self, size_limit: usize) -> &mut Self {
103		self.message_size_limit = size_limit;
104
105		self
106	}
107
108	/// Turns off the auto unmasking the messages.
109	pub fn turn_off_auto_unmasking(&mut self) -> &mut Self {
110		self.auto_unmasking = false;
111
112		self
113	}
114
115	/// Turns off automatically sending the *pong* messages.
116	pub fn turn_off_auto_sending_pong(&mut self) -> &mut Self {
117		self.auto_sending_pong = false;
118
119		self
120	}
121
122	/// Turns on auto-responding to *close* messages.
123	pub fn turn_on_auto_closing(&mut self) -> &mut Self {
124		self.auto_closing = true;
125
126		self
127	}
128
129	/// Returns a `Response` that should be sent to the client and calls the given callback
130	/// on upgrade to handle the result.
131	pub fn upgrade<Func, Fut>(self, handle_upgrade_result: Func) -> Response
132	where
133		Func: FnOnce(Result<WebSocket, WebSocketUpgradeError>) -> Fut + Send + 'static,
134		Fut: Future<Output = ()>,
135	{
136		let Self {
137			mut response,
138			upgrade_future,
139			some_requested_protocols: _,
140			some_selected_protocol,
141			message_size_limit,
142			auto_unmasking,
143			auto_sending_pong,
144			auto_closing,
145		} = self;
146
147		tokio::spawn(async move {
148			let result = upgrade_future.await.map(|mut fws| {
149				fws.set_max_message_size(message_size_limit);
150				fws.set_auto_apply_mask(auto_unmasking);
151				fws.set_auto_pong(auto_sending_pong);
152				fws.set_auto_close(auto_closing);
153
154				WebSocket(FragmentCollector::new(fws))
155			});
156
157			handle_upgrade_result(result);
158		});
159
160		if let Some(selected_protocol) = some_selected_protocol {
161			response
162				.headers_mut()
163				.insert(SEC_WEBSOCKET_PROTOCOL, selected_protocol);
164		}
165
166		response
167	}
168}
169
170impl<B> FromRequest<B> for WebSocketUpgrade {
171	type Error = WebSocketUpgradeError;
172
173	fn from_request(
174		head_parts: &mut RequestHeadParts,
175		_: B,
176	) -> impl Future<Output = Result<Self, Self::Error>> {
177		ready(websocket_handshake(head_parts))
178	}
179}
180
181pub(crate) fn websocket_handshake(
182	head: &mut RequestHeadParts,
183) -> Result<WebSocketUpgrade, WebSocketUpgradeError> {
184	if head.method != Method::GET {
185		panic!("WebSocket is not supported with methods other than GET")
186	}
187
188	if !head
189		.headers
190		.get(CONNECTION)
191		.is_some_and(|header_value| header_value.as_bytes().eq_ignore_ascii_case(b"upgrade"))
192	{
193		return Err(WebSocketUpgradeError::InvalidConnectionHeader);
194	}
195
196	if !head
197		.headers
198		.get(UPGRADE)
199		.is_some_and(|header_value| header_value.as_bytes().eq_ignore_ascii_case(b"websocket"))
200	{
201		return Err(WebSocketUpgradeError::InvalidUpgradeHeader);
202	}
203
204	if !head
205		.headers
206		.get(SEC_WEBSOCKET_VERSION)
207		.is_some_and(|header_value| header_value.as_bytes() == b"13")
208	{
209		return Err(WebSocketUpgradeError::InvalidSecWebSocketVersion);
210	}
211
212	let Some(sec_websocket_accept_value) = head
213		.headers
214		.get(SEC_WEBSOCKET_KEY)
215		.map(|header_value| sec_websocket_accept_value_from(header_value.as_bytes()))
216	else {
217		return Err(WebSocketUpgradeError::MissingSecWebSocketKey);
218	};
219
220	let Some(upgrade_future) = head.extensions.remove::<OnUpgrade>().map(UpgradeFuture) else {
221		return Err(WebSocketUpgradeError::UnupgradableConnection);
222	};
223
224	let some_requested_protocols = head.headers.get(SEC_WEBSOCKET_PROTOCOL);
225
226	let mut response = StatusCode::SWITCHING_PROTOCOLS.into_response();
227
228	response
229		.headers_mut()
230		.insert(CONNECTION, HeaderValue::from_static("upgrade"));
231
232	response
233		.headers_mut()
234		.insert(UPGRADE, HeaderValue::from_static("websocket"));
235
236	response
237		.headers_mut()
238		.insert(SEC_WEBSOCKET_ACCEPT, sec_websocket_accept_value);
239
240	Ok(WebSocketUpgrade::new(
241		response,
242		upgrade_future,
243		some_requested_protocols.cloned(),
244	))
245}
246
247fn sec_websocket_accept_value_from(key: &[u8]) -> HeaderValue {
248	let mut sha1 = Sha1::new();
249	sha1.update(key);
250	sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
251
252	let b64 = BASE64_STANDARD.encode(sha1.finalize());
253	HeaderValue::try_from(b64).expect("base64 encoded value must be a valid header value")
254}
255
256// --------------------------------------------------
257// WebSocketUpgradeError
258
259/// An error type that's returned on WebSocket upgrade failures.
260#[derive(Debug, crate::ImplError)]
261pub enum WebSocketUpgradeError {
262	/// Returned when `Connection` header is invalid.
263	#[error("invalid Connection header")]
264	InvalidConnectionHeader,
265	/// Returned when `Upgrade` header is invalid.
266	#[error("invalid Upgrade header")]
267	InvalidUpgradeHeader,
268	/// Returned when `Sec-WebSocket-Version` is not 13.
269	#[error("invalid Sec-WebSocket-Version")]
270	InvalidSecWebSocketVersion,
271	/// Returned when `Sec-WebSocket-Key` is missing.
272	#[error("missing Sec-WebSocket-Key")]
273	MissingSecWebSocketKey,
274	/// Returned on failure when converting the `Sec-WebSocket-Protocol` to a string.
275	#[error("invlaid Sec-WebSocket-Protocol")]
276	InvalidSecWebSocketProtocol(ToStrError),
277	/// Returned when the connection wasn't configured to be upgradable.
278	#[error("unupgradable connection")]
279	UnupgradableConnection,
280	/// Returned on low-level failures.
281	#[error(transparent)]
282	Failure(#[from] hyper::Error),
283}
284
285impl IntoResponse for WebSocketUpgradeError {
286	fn into_response(self) -> Response {
287		use WebSocketUpgradeError::*;
288
289		match self {
290			InvalidConnectionHeader
291			| InvalidUpgradeHeader
292			| InvalidSecWebSocketVersion
293			| MissingSecWebSocketKey
294			| InvalidSecWebSocketProtocol(_) => StatusCode::BAD_REQUEST.into_response(),
295			_ => StatusCode::INTERNAL_SERVER_ERROR.into_response(),
296		}
297	}
298}
299
300// --------------------------------------------------
301// UpgradeFuture
302
303struct UpgradeFuture(OnUpgrade);
304
305impl Future for UpgradeFuture {
306	type Output = Result<FastWebSocket<TokioIo<Upgraded>>, WebSocketUpgradeError>;
307
308	fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
309		match self.0.poll_unpin(cx) {
310			Poll::Ready(result) => Poll::Ready(
311				result
312					.map(|upgraded| FastWebSocket::after_handshake(TokioIo::new(upgraded), Role::Server))
313					.map_err(WebSocketUpgradeError::Failure),
314			),
315			Poll::Pending => Poll::Pending,
316		}
317	}
318}
319
320// --------------------------------------------------
321// WebSocket
322
323/// A successfully established WebSocket.
324pub struct WebSocket(FragmentCollector<TokioIo<Upgraded>>);
325
326impl WebSocket {
327	/// Receives a message.
328	///
329	/// Returns `None` if the connection has been closed.
330	pub async fn receive(&mut self) -> Option<Result<Message, WebSocketError>> {
331		match self.0.read_frame().await {
332			Ok(complete_frame) => match complete_frame.opcode {
333				OpCode::Text => {
334					// Price of #![forbid(unsafe_code)]
335					let text = String::from_utf8(complete_frame.payload.to_vec())
336						.expect("text payload should have been guaranteed to be a valid utf-8");
337
338					Some(Ok(Message::Text(text)))
339				}
340				OpCode::Binary => Some(Ok(Message::Binary(complete_frame.payload.to_vec()))),
341				OpCode::Ping => Some(Ok(Message::Binary(complete_frame.payload.to_vec()))),
342				OpCode::Pong => Some(Ok(Message::Binary(complete_frame.payload.to_vec()))),
343				OpCode::Close => Some(Ok(Message::Close(None))),
344				OpCode::Continuation => Some(Err(WebSocketError::Unexpected(IncompleteMessage.into()))),
345			},
346			Err(error) => {
347				if let FastWebSocketError::ConnectionClosed = error {
348					return None;
349				}
350
351				Some(Err(error.into()))
352			}
353		}
354	}
355
356	/// Sends a new message.
357	pub async fn send(&mut self, message: Message) -> Result<(), WebSocketError> {
358		match message {
359			Message::Text(text) => {
360				let frame = Frame::text(Payload::Owned(text.into()));
361
362				self.0.write_frame(frame).await?
363			}
364			Message::Binary(binary) => {
365				let frame = Frame::binary(Payload::Owned(binary));
366
367				self.0.write_frame(frame).await?
368			}
369			Message::Ping(ping) => {
370				let frame = Frame::new(true, OpCode::Ping, None, Payload::Owned(ping));
371
372				self.0.write_frame(frame).await?
373			}
374			Message::Pong(pong) => {
375				let frame = Frame::pong(Payload::Owned(pong));
376
377				self.0.write_frame(frame).await?
378			}
379			Message::Close(some_close_frame) => {
380				let frame = if let Some(CloseFrame { code, reason }) = some_close_frame {
381					Frame::close(code.into(), reason.as_bytes())
382				} else {
383					Frame::close(CloseCode::_1000_Normal.into(), b"")
384				};
385
386				self.0.write_frame(frame).await?
387			}
388		};
389		Ok(())
390	}
391
392	/// Sends a *'close frame'* to the peer and closes the connection.
393	#[inline(always)]
394	pub async fn close(mut self) -> Result<(), WebSocketError> {
395		self.send(Message::Close(None)).await
396	}
397}
398
399// --------------------------------------------------
400// Message
401
402/// A WebScoket message.
403pub enum Message {
404	Text(String),
405	Binary(Vec<u8>),
406	Ping(Vec<u8>),
407	Pong(Vec<u8>),
408	Close(Option<CloseFrame>),
409}
410
411// ----------
412
413/// A *close frame* to send when manually closing the connection.
414pub struct CloseFrame {
415	code: CloseCode,
416	reason: Cow<'static, str>,
417}
418
419// --------------------------------------------------
420// CloseCode
421
/// A *close code* indicating the reason for the closure.
// Variant names embed the numeric code for readability, which rules out camel case.
#[allow(non_camel_case_types)]
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum CloseCode {
	/// Indicates a normal closure, meaning that the purpose for
	/// which the connection was established has been fulfilled.
	_1000_Normal,

	/// Indicates that an endpoint is "going away", such as a server
	/// going down or a browser having navigated away from a page.
	_1001_GoingAway,

	/// Indicates that an endpoint is terminating the connection due
	/// to a protocol error.
	_1002_ProtocolError,

	/// Indicates that an endpoint is terminating the connection
	/// because it has received a type of data it cannot accept (e.g., an
	/// endpoint that understands only text data MAY send this if it
	/// receives a binary message).
	_1003_UnsupportedData,

	/// Reserved. Indicates that no status code was present in the *close frame*.
	/// It must not be sent in a *close frame* by an endpoint.
	_1005_NoStatusReceived,

	/// Reserved. Indicates that the connection was closed abnormally, e.g. without
	/// sending or receiving a *close frame*. It must not be sent in a *close frame*
	/// by an endpoint.
	_1006_Abnormal,

	/// Indicates that an endpoint is terminating the connection
	/// because it has received data within a message that was not
	/// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\]
	/// data within a text message).
	_1007_InvalidPayloadData,

	/// Indicates that an endpoint is terminating the connection
	/// because it has received a message that violates its policy. This
	/// is a generic status code that can be returned when there is no
	/// other more suitable status code (e.g., Unsupported or Size) or if there
	/// is a need to hide specific details about the policy.
	_1008_PolicyViolation,

	/// Indicates that an endpoint is terminating the connection
	/// because it has received a message that is too big for it to
	/// process.
	_1009_MessageTooBig,

	/// Indicates that an endpoint (client) is terminating the
	/// connection because it has expected the server to negotiate one or
	/// more extension, but the server didn't return them in the response
	/// message of the WebSocket handshake. The list of extensions that
	/// are needed should be given as the reason for closing.
	/// Note that this status code is not used by the server, because it
	/// can fail the WebSocket handshake instead.
	_1010_MandatoryExtension,

	/// Indicates that a server is terminating the connection because
	/// it encountered an unexpected condition that prevented it from
	/// fulfilling the request.
	_1011_InternalError,

	/// Indicates that the server is restarting. A client may choose to reconnect,
	/// and if it does, it should use a randomized delay of 5-30 seconds between attempts.
	_1012_ServerRestart,

	/// Indicates that the server is overloaded and the client should either connect
	/// to a different IP (when multiple targets exist), or reconnect to the same IP
	/// when a user has performed an action.
	_1013_TryLater,

	/// The server was acting as a gateway or proxy and received an invalid response
	/// from the upstream server.
	_1014_BadGateway,

	/// Reserved. Indicates that the connection was closed due to a failure to perform
	/// a TLS handshake (e.g., the server certificate can't be verified). It must not be
	/// sent in a *close frame* by an endpoint.
	_1015_TlsError,

	// Codes without a dedicated variant, classified by the RFC 6455 (section 7.4.2) ranges.
	// Codes 0-999 are not used.
	#[doc(hidden)]
	Unused(u16),
	// Codes 1004 and 1016-2999 are reserved for the protocol and its extensions.
	#[doc(hidden)]
	Reserved(u16),
	// Codes 3000-3999 are registered with IANA by libraries, frameworks and applications.
	#[doc(hidden)]
	IanaRegistered(u16),
	// Codes 4000-4999 are for private use.
	#[doc(hidden)]
	Private(u16),
	// Codes 5000 and above are outside every range defined by RFC 6455.
	#[doc(hidden)]
	Bad(u16),
}

impl From<u16> for CloseCode {
	fn from(value: u16) -> Self {
		match value {
			1000 => Self::_1000_Normal,
			1001 => Self::_1001_GoingAway,
			1002 => Self::_1002_ProtocolError,
			1003 => Self::_1003_UnsupportedData,
			1005 => Self::_1005_NoStatusReceived,
			1006 => Self::_1006_Abnormal,
			1007 => Self::_1007_InvalidPayloadData,
			1008 => Self::_1008_PolicyViolation,
			1009 => Self::_1009_MessageTooBig,
			1010 => Self::_1010_MandatoryExtension,
			1011 => Self::_1011_InternalError,
			1012 => Self::_1012_ServerRestart,
			1013 => Self::_1013_TryLater,
			1014 => Self::_1014_BadGateway,
			1015 => Self::_1015_TlsError,
			// RFC 6455: "Status codes in the range 0-999 are not used."
			0..=999 => Self::Unused(value),
			// 1004 is reserved by RFC 6455 and has no dedicated variant.
			1004 | 1016..=2999 => Self::Reserved(value),
			3000..=3999 => Self::IanaRegistered(value),
			4000..=4999 => Self::Private(value),
			_ => Self::Bad(value),
		}
	}
}

impl From<CloseCode> for u16 {
	fn from(close_code: CloseCode) -> Self {
		match close_code {
			CloseCode::_1000_Normal => 1000,
			CloseCode::_1001_GoingAway => 1001,
			CloseCode::_1002_ProtocolError => 1002,
			CloseCode::_1003_UnsupportedData => 1003,
			CloseCode::_1005_NoStatusReceived => 1005,
			CloseCode::_1006_Abnormal => 1006,
			CloseCode::_1007_InvalidPayloadData => 1007,
			CloseCode::_1008_PolicyViolation => 1008,
			CloseCode::_1009_MessageTooBig => 1009,
			CloseCode::_1010_MandatoryExtension => 1010,
			CloseCode::_1011_InternalError => 1011,
			CloseCode::_1012_ServerRestart => 1012,
			CloseCode::_1013_TryLater => 1013,
			CloseCode::_1014_BadGateway => 1014,
			CloseCode::_1015_TlsError => 1015,
			CloseCode::Unused(code)
			| CloseCode::Reserved(code)
			| CloseCode::IanaRegistered(code)
			| CloseCode::Private(code)
			| CloseCode::Bad(code) => code,
		}
	}
}
574
575// --------------------------------------------------
576// WebSocketError
577
578/// An error type that's returned on WebSocket communication failure.
579#[non_exhaustive]
580#[derive(Debug, crate::ImplError)]
581pub enum WebSocketError {
582	/// Returned when invalid frame is deteced.
583	#[error("invalid fragment")]
584	InvalidFragment,
585	/// Returned when text message has an invalid UTF-8 character.
586	#[error("invalid UTF-8")]
587	InvalidUTF8,
588	/// Returned when invalid continuation frame is deteced.
589	#[error("invalid continuation frame")]
590	InvalidContinuationFrame,
591	/// Returned when *close frame* is invalid.
592	#[error("invalid close frame")]
593	InvalidCloseFrame,
594	/// Returned when *close code* is invalid.
595	#[error("invalid close code")]
596	InvalidCloseCode,
597	/// Returned on unexpected *end of file*.
598	#[error("unexpected EOF")]
599	UnexpectedEOF,
600	/// Returned when a frame has non-zero reserved bits.
601	#[error("non-zero reserved bits")]
602	NonZeroReservedBits,
603	/// Returned when a fragmented *control frame* is detected.
604	#[error("fragmented control frame")]
605	FragmentedControlFrame,
606	/// Returned when a *ping frame* is too large.
607	#[error("ping frame too large")]
608	PingFrameTooLarge,
609	/// Returned when the received message exceeded the size limit.
610	#[error("message too large ")]
611	MessageTooLarge,
612	/// Returned on invalid value.
613	#[error("Invalid value")]
614	InvalidValue,
615	#[error(transparent)]
616	/// Returned on IO error.
617	Io(#[from] IoError),
618	#[error(transparent)]
619	/// Returned on unexpected error.
620	Unexpected(BoxedError),
621}
622
623impl From<FastWebSocketError> for WebSocketError {
624	fn from(fast_web_socket_error: FastWebSocketError) -> Self {
625		match fast_web_socket_error {
626			FastWebSocketError::InvalidFragment => Self::InvalidFragment,
627			FastWebSocketError::InvalidUTF8 => Self::InvalidUTF8,
628			FastWebSocketError::InvalidContinuationFrame => Self::InvalidContinuationFrame,
629			FastWebSocketError::InvalidCloseFrame => Self::InvalidCloseFrame,
630			FastWebSocketError::InvalidCloseCode => Self::InvalidCloseCode,
631			FastWebSocketError::UnexpectedEOF => Self::UnexpectedEOF,
632			FastWebSocketError::ReservedBitsNotZero => Self::NonZeroReservedBits,
633			FastWebSocketError::ControlFrameFragmented => Self::FragmentedControlFrame,
634			FastWebSocketError::PingFrameTooLarge => Self::PingFrameTooLarge,
635			FastWebSocketError::FrameTooLarge => Self::MessageTooLarge,
636			FastWebSocketError::InvalidValue => Self::InvalidValue,
637			FastWebSocketError::IoError(io_error) => Self::Io(io_error),
638			unexpected_error => Self::Unexpected(unexpected_error.into()),
639		}
640	}
641}
642
/// Error reported (wrapped in `WebSocketError::Unexpected`) when `WebSocket::receive`
/// reads a stray *continuation frame* instead of a complete message.
#[derive(Debug, crate::ImplError)]
#[error("incomplete message")]
struct IncompleteMessage;
646
647// --------------------------------------------------------------------------------