1use std::{
4 fmt,
5 ops::{Deref, DerefMut},
6};
7
8use rama_core::{
9 Service,
10 error::{ErrorContext, OpaqueError},
11 extensions::{Extensions, ExtensionsMut, ExtensionsRef},
12 matcher::Matcher,
13 rt::Executor,
14 telemetry::tracing::{self, Instrument},
15};
16#[cfg(feature = "compression")]
17use rama_http::headers::sec_websocket_extensions;
18use rama_http::{
19 Method, Request, Response, StatusCode, Version,
20 headers::{
21 self, HeaderMapExt,
22 sec_websocket_extensions::{Extension, PerMessageDeflateConfig},
23 },
24 io::upgrade,
25 proto::h2::ext::Protocol,
26 request,
27 service::web::response::{self, Headers, IntoResponse},
28};
29use rama_utils::{
30 collections::non_empty_smallvec,
31 str::{NonEmptyStr, non_empty_str},
32};
33
34use crate::{
35 Message,
36 protocol::{Role, WebSocketConfig},
37 runtime::AsyncWebSocket,
38};
39
/// A matcher that recognises WebSocket handshake requests.
///
/// See its [`Matcher`] implementation for the exact HTTP/1.x and HTTP/2 rules.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct WebSocketMatcher;

impl WebSocketMatcher {
    /// Creates a new [`WebSocketMatcher`] (equivalent to [`WebSocketMatcher::default`]).
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
58
59impl<Body> Matcher<Request<Body>> for WebSocketMatcher
60where
61 Body: Send + 'static,
62{
63 fn matches(&self, _ext: Option<&mut Extensions>, req: &Request<Body>) -> bool {
64 match req.version() {
65 version @ (Version::HTTP_10 | Version::HTTP_11) => {
66 match req.method() {
67 &Method::GET => (),
68 method => {
69 tracing::debug!(
70 http.version = ?version,
71 http.request.method = %method,
72 "WebSocketMatcher: h1: unexpected method found: no match",
73 );
74 return false;
75 }
76 }
77
78 if !req
79 .headers()
80 .typed_get::<headers::Upgrade>()
81 .map(|u| u.is_websocket())
82 .unwrap_or_default()
83 {
84 tracing::trace!(
85 http.version = ?version,
86 "WebSocketMatcher: h1: no websocket upgrade header found: no match"
87 );
88 return false;
89 }
90
91 if !req
92 .headers()
93 .typed_get::<headers::Connection>()
94 .map(|c| c.contains_upgrade())
95 .unwrap_or_default()
96 {
97 tracing::trace!(
98 http.version = ?version,
99 "WebSocketMatcher: h1: no connection upgrade header found: no match",
100 );
101 return false;
102 }
103 }
104 version @ Version::HTTP_2 => {
105 match req.method() {
106 &Method::CONNECT => (),
107 method => {
108 tracing::debug!(
109 http.version = ?version,
110 http.request.method = %method,
111 "WebSocketMatcher: h2: unexpected method found: no match",
112 );
113 return false;
114 }
115 }
116
117 if !req
118 .extensions()
119 .get::<Protocol>()
120 .map(|p| p.as_str().trim().eq_ignore_ascii_case("websocket"))
121 .unwrap_or_default()
122 {
123 tracing::trace!(
124 http.version = ?version,
125 "WebSocketMatcher: h2: no websocket protocol (pseudo ext) found",
126 );
127 return false;
128 }
129 }
130 version => {
131 tracing::debug!(
132 http.version = ?version,
133 "WebSocketMatcher: unexpected http version found: no match",
134 );
135 return false;
136 }
137 }
138
139 true
140 }
141}
142
143#[derive(Debug)]
144pub enum RequestValidateError {
146 UnexpectedHttpMethod(Method),
147 UnexpectedHttpVersion(Version),
148 UnexpectedPseudoProtocolHeader(Option<Protocol>),
149 MissingUpgradeWebSocketHeader,
150 MissingConnectionUpgradeHeader,
151 InvalidSecWebSocketVersionHeader,
152 InvalidSecWebSocketKeyHeader,
153 InvalidSecWebSocketProtocolHeader(OpaqueError),
154}
155
156impl fmt::Display for RequestValidateError {
157 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158 match self {
159 Self::UnexpectedHttpMethod(method) => {
160 write!(f, "unexpected HTTP method: {method:?}")
161 }
162 Self::UnexpectedHttpVersion(version) => {
163 write!(f, "unexpected HTTP version: {version:?}")
164 }
165 Self::UnexpectedPseudoProtocolHeader(maybe_protocol) => {
166 write!(
167 f,
168 "missing or invalid pseudo h2 protocol header: {maybe_protocol:?}"
169 )
170 }
171 Self::MissingUpgradeWebSocketHeader => {
172 write!(f, "missing upgrade WebSocket header")
173 }
174 Self::MissingConnectionUpgradeHeader => {
175 write!(f, "missing connection upgrade header")
176 }
177 Self::InvalidSecWebSocketVersionHeader => {
178 write!(f, "missing or invalid sec-websocket-version header")
179 }
180 Self::InvalidSecWebSocketKeyHeader => {
181 write!(f, "missing or invalid sec-websocket-key header")
182 }
183 Self::InvalidSecWebSocketProtocolHeader(err) => {
184 write!(f, "invalid sec-websocket-protocol header: {err}")
185 }
186 }
187 }
188}
189
190impl std::error::Error for RequestValidateError {}
191
/// Data extracted from a valid WebSocket client handshake request by
/// [`validate_http_client_request`].
#[derive(Debug)]
pub struct ClientRequestData {
    /// `Sec-WebSocket-Accept` value derived from the request's `Sec-WebSocket-Key`.
    ///
    /// Only computed for HTTP/1.x requests. It is `None` for HTTP/2 requests, or when
    /// deriving it from the key failed (that failure is only logged).
    pub accept_header: Option<headers::SecWebSocketAccept>,
    /// The typed `Sec-WebSocket-Protocol` request header, as returned by `typed_get`
    /// (`None` when absent or, presumably, not decodable).
    pub protocol: Option<headers::SecWebSocketProtocol>,
    /// The typed `Sec-WebSocket-Extensions` request header, as returned by `typed_get`
    /// (`None` when absent or, presumably, not decodable).
    pub extensions: Option<headers::SecWebSocketExtensions>,
}
198
199pub fn validate_http_client_request<Body>(
200 request: &Request<Body>,
201) -> Result<ClientRequestData, RequestValidateError> {
202 tracing::trace!(
203 http.version = ?request.version(),
204 "validate http client request"
205 );
206
207 let mut accept_header = None;
208
209 match request.version() {
210 Version::HTTP_10 | Version::HTTP_11 => {
211 match request.method() {
212 &Method::GET => (),
213 method => return Err(RequestValidateError::UnexpectedHttpMethod(method.clone())),
214 }
215
216 if !request
221 .headers()
222 .typed_get::<headers::Upgrade>()
223 .map(|u| u.is_websocket())
224 .unwrap_or_default()
225 {
226 return Err(RequestValidateError::MissingUpgradeWebSocketHeader);
227 }
228
229 if !request
234 .headers()
235 .typed_get::<headers::Connection>()
236 .map(|c| c.contains_upgrade())
237 .unwrap_or_default()
238 {
239 return Err(RequestValidateError::MissingConnectionUpgradeHeader);
240 }
241
242 accept_header = match request.headers().typed_get::<headers::SecWebSocketKey>() {
249 Some(key) => headers::SecWebSocketAccept::try_from(key)
250 .inspect_err(|err| {
251 tracing::debug!(
252 "failed to create accept typed header from given key: {err}"
253 )
254 })
255 .ok(),
256 None => return Err(RequestValidateError::InvalidSecWebSocketKeyHeader),
257 };
258 }
259 Version::HTTP_2 => {
260 match request.method() {
261 &Method::CONNECT => (),
262 method => return Err(RequestValidateError::UnexpectedHttpMethod(method.clone())),
263 }
264
265 match request.extensions().get::<Protocol>() {
266 None => return Err(RequestValidateError::UnexpectedPseudoProtocolHeader(None)),
267 Some(protocol) => {
268 if !protocol.as_str().trim().eq_ignore_ascii_case("websocket") {
269 return Err(RequestValidateError::UnexpectedPseudoProtocolHeader(Some(
270 protocol.clone(),
271 )));
272 }
273 }
274 }
275 }
276 version => {
277 return Err(RequestValidateError::UnexpectedHttpVersion(version));
278 }
279 }
280
281 if request
283 .headers()
284 .typed_get::<headers::SecWebSocketVersion>()
285 .is_none()
286 {
287 return Err(RequestValidateError::InvalidSecWebSocketVersionHeader);
288 }
289
290 let protocols_header = request.headers().typed_get();
294
295 let extensions_header = request.headers().typed_get();
299
300 Ok(ClientRequestData {
301 accept_header,
302 protocol: protocols_header,
303 extensions: extensions_header,
304 })
305}
306
/// Validates WebSocket client handshake requests and builds the handshake response.
///
/// It also negotiates the sub-protocol and the (per-message-deflate) extension. Used as a
/// [`Service`], it returns the handshake response together with the request, or a
/// `400 Bad Request` response as its error. Use [`WebSocketAcceptor::into_service`] to also
/// serve the upgraded connection.
#[derive(Debug, Clone, Default)]
pub struct WebSocketAcceptor {
    /// Sub-protocols the server expects; `None` means no sub-protocol is expected.
    protocols: Option<headers::SecWebSocketProtocol>,
    /// When `true`, a mismatch in sub-protocol *presence* between client and server is
    /// tolerated instead of rejected.
    protocols_flex: bool,

    /// Extensions the server allows. Only per-message-deflate entries are negotiated.
    extensions: Option<headers::SecWebSocketExtensions>,
}
317
318impl WebSocketAcceptor {
319 #[inline]
320 #[must_use]
322 pub fn new() -> Self {
323 Default::default()
324 }
325
326 rama_utils::macros::generate_set_and_with! {
327 pub fn protocols_flex(mut self, flexible: bool) -> Self {
334 self.protocols_flex = flexible;
335 self
336 }
337 }
338
339 rama_utils::macros::generate_set_and_with! {
340 pub fn protocols(mut self, protocols: Option<headers::SecWebSocketProtocol>) -> Self {
346 self.protocols = protocols;
347 self
348 }
349 }
350
351 rama_utils::macros::generate_set_and_with! {
352 pub fn echo_protocols(mut self) -> Self {
354 self.protocols = Some(headers::SecWebSocketProtocol(non_empty_smallvec![
355 ECHO_SERVICE_SUB_PROTOCOL_DEFAULT,
356 ECHO_SERVICE_SUB_PROTOCOL_UPPER,
357 ECHO_SERVICE_SUB_PROTOCOL_LOWER,
358 ]));
359 self
360 }
361 }
362
363 rama_utils::macros::generate_set_and_with! {
364 pub fn extensions(mut self, extensions: Option<headers::SecWebSocketExtensions>) -> Self {
366 self.extensions = extensions;
367 self
368 }
369 }
370
371 #[cfg(feature = "compression")]
372 rama_utils::macros::generate_set_and_with! {
373 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
375 pub fn per_message_deflate(mut self) -> Self {
376 self.extensions = match self.extensions.take() {
377 Some(ext) => {
378 Some(ext.with_extra_extension(Extension::PerMessageDeflate(Default::default())))
379 },
380 None => Some(headers::SecWebSocketExtensions::per_message_deflate()),
381 };
382 self
383 }
384 }
385
386 #[cfg(feature = "compression")]
387 rama_utils::macros::generate_set_and_with! {
388 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
391 pub fn per_message_deflate_overwrite_extensions(mut self) -> Self {
392 self.extensions = Some(headers::SecWebSocketExtensions::per_message_deflate());
393 self
394 }
395 }
396
397 #[cfg(feature = "compression")]
398 rama_utils::macros::generate_set_and_with! {
399 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
402 pub fn per_message_deflate_with_config(mut self, config: impl Into<sec_websocket_extensions::PerMessageDeflateConfig>) -> Self {
403 self.extensions = match self.extensions.take() {
404 Some(ext) => {
405 Some(ext.with_extra_extension(Extension::PerMessageDeflate(config.into())))
406 },
407 None => Some(headers::SecWebSocketExtensions::per_message_deflate_with_config(config.into())),
408 };
409 self
410 }
411 }
412
413 #[cfg(feature = "compression")]
414 rama_utils::macros::generate_set_and_with! {
415 #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
418 pub fn per_message_deflate_with_config_overwrite_extensions(mut self, config: impl Into<sec_websocket_extensions::PerMessageDeflateConfig>) -> Self {
419 self.extensions = Some(headers::SecWebSocketExtensions::per_message_deflate_with_config(config.into()));
420 self
421 }
422 }
423}
424
425impl WebSocketAcceptor {
426 pub fn into_service<S>(self, service: S) -> WebSocketAcceptorService<S> {
430 WebSocketAcceptorService {
431 acceptor: self,
432 config: None,
433 service,
434 }
435 }
436
437 #[must_use]
439 pub fn into_echo_service(mut self) -> WebSocketAcceptorService<WebSocketEchoService> {
440 if self.protocols.is_none() {
441 self.protocols_flex = true;
442 self.protocols = Some(headers::SecWebSocketProtocol(non_empty_smallvec![
443 ECHO_SERVICE_SUB_PROTOCOL_DEFAULT,
444 ECHO_SERVICE_SUB_PROTOCOL_UPPER,
445 ECHO_SERVICE_SUB_PROTOCOL_LOWER,
446 ]));
447 }
448
449 WebSocketAcceptorService {
450 acceptor: self,
451 config: None,
452 service: WebSocketEchoService::new(),
453 }
454 }
455}
456
// Handshake service: `Output` is the handshake response together with the (annotated)
// request; `Error` is the response to send back when the handshake is rejected.
impl<Body> Service<Request<Body>> for WebSocketAcceptor
where
    Body: Send + 'static,
{
    type Output = (Response, Request<Body>);
    type Error = Response;

    /// Validates `req` as a WebSocket handshake and builds the handshake response.
    ///
    /// On success it returns the response together with `req`. The response is
    /// `101 Switching Protocols` for HTTP/1.x and `200 OK` for HTTP/2. The accepted
    /// sub-protocol and extension, if any, have been inserted into the request's extensions
    /// and added as response headers.
    ///
    /// Every rejection is returned as `Err` holding a `400 Bad Request` response. When the
    /// `Sec-WebSocket-Version` header was invalid, that response also carries
    /// `Sec-WebSocket-Version: 13`.
    async fn serve(&self, mut req: Request<Body>) -> Result<Self::Output, Self::Error> {
        match validate_http_client_request(&req) {
            Ok(request_data) => {
                // Sub-protocol negotiation, keyed on
                // (flex mode, protocols offered by client, protocols expected by server).
                let accepted_protocol = match (
                    self.protocols_flex,
                    request_data.protocol,
                    self.protocols.as_ref(),
                ) {
                    // Strict: the client offers protocols but the server expects none.
                    (false, Some(protocols), None) => {
                        tracing::debug!(
                            "WebSocketAcceptor: protocols found while none were expected: {protocols:?}"
                        );
                        return Err(StatusCode::BAD_REQUEST.into_response());
                    }
                    // Strict: the server expects a protocol but the client offers none.
                    (false, None, Some(protocols)) => {
                        tracing::debug!(
                            "WebSocketAcceptor: no protocols found while one of following was expected: {protocols:?}"
                        );
                        return Err(StatusCode::BAD_REQUEST.into_response());
                    }
                    // Nothing offered and nothing expected, or flex mode tolerating a
                    // missing offer: continue without a sub-protocol.
                    (_, None, None) | (true, None, Some(_)) => None,
                    // Flex mode without server expectations: take the client's first one.
                    (true, Some(found_protocols), None) => {
                        Some(found_protocols.accept_first_protocol())
                    }
                    // Both sides have protocols (any mode): one of them has to match.
                    (_, Some(found_protocols), Some(expected_protocols)) => {
                        if let Some(protocol) =
                            found_protocols.contains_any(expected_protocols.iter())
                        {
                            Some(protocol)
                        } else {
                            tracing::debug!(
                                "WebSocketAcceptor: no protocols from found protocol ({found_protocols:?}) matched for expected protocols: {expected_protocols:?}"
                            );
                            return Err(StatusCode::BAD_REQUEST.into_response());
                        }
                    }
                };

                // Extension negotiation. Only per-message-deflate (PMD) is considered. The
                // first client extension that is PMD, paired with the first PMD entry of the
                // server config, yields the accepted extension. Anything else is ignored,
                // never rejected.
                let accepted_extension = match (request_data.extensions, self.extensions.as_ref()) {
                    (None, _) | (_, None) => None,
                    (Some(request_extensions), Some(allowed_extensions)) => {
                        request_extensions.0.iter().find_map(|request_ext| {
                            for allowed_ext in allowed_extensions.0.iter() {
                                if let (
                                    Extension::PerMessageDeflate(request_pmd),
                                    Extension::PerMessageDeflate(allowed_pmd),
                                ) = (&request_ext, allowed_ext)
                                {
                                    // Identifier comes from the server config.
                                    // client_no_context_takeover is set only when both sides
                                    // ask for it. server_no_context_takeover is taken from the
                                    // server config alone. Other fields start at their defaults.
                                    let mut resp = PerMessageDeflateConfig {
                                        identifier: allowed_pmd.identifier.clone(),
                                        client_no_context_takeover: request_pmd
                                            .client_no_context_takeover
                                            && allowed_pmd.client_no_context_takeover,
                                        server_no_context_takeover: allowed_pmd
                                            .server_no_context_takeover,
                                        ..Default::default()
                                    };

                                    // Server window cap from the server config: 15 when unset,
                                    // 0 treated as 15, otherwise clamped into 8..=15.
                                    // NOTE(review): 0 presumably encodes "parameter without
                                    // value" — confirm against the header parser.
                                    let srv_cap = allowed_pmd.server_max_window_bits.unwrap_or(15);
                                    let srv_cap = if srv_cap == 0 {
                                        15
                                    } else {
                                        srv_cap.clamp(8, 15)
                                    };
                                    // server_max_window_bits requested by the client,
                                    // normalised the same way.
                                    let cli_req_srv = request_pmd
                                        .server_max_window_bits
                                        .map(|v| if v == 0 { 15 } else { v.clamp(8, 15) });
                                    // Pick the smaller of the client's request and the server
                                    // cap. The second tuple element is always `Some`, so the
                                    // `_` arm is unreachable.
                                    let chosen_srv_bits = match (cli_req_srv, Some(srv_cap)) {
                                        (Some(client_bits), Some(cap)) => {
                                            Some(client_bits.min(cap))
                                        }
                                        (None, Some(cap)) => Some(cap),
                                        _ => None,
                                    };
                                    // Only advertise server_max_window_bits when it is below the
                                    // default of 15, or when the client explicitly asked for it.
                                    resp.server_max_window_bits = match chosen_srv_bits {
                                        Some(bits) if bits < 15 || cli_req_srv.is_some() => {
                                            Some(bits)
                                        }
                                        _ => None,
                                    };

                                    // client_max_window_bits is only answered when the client
                                    // offered it. The offer is normalised (0 -> 15, clamped into
                                    // 8..=15) and lowered to the server's cap when the server
                                    // config sets a non-zero one.
                                    resp.client_max_window_bits = request_pmd
                                        .client_max_window_bits
                                        .map(|client_bits_offer| {
                                            let offer = if client_bits_offer == 0 {
                                                15
                                            } else {
                                                client_bits_offer.clamp(8, 15)
                                            };
                                            let cap =
                                                allowed_pmd.client_max_window_bits.unwrap_or(offer);
                                            if cap == 0 {
                                                offer
                                            } else {
                                                offer.min(cap.clamp(8, 15))
                                            }
                                        });

                                    tracing::trace!(
                                        "accept and use ws deflate ext w/ config: {resp:?}"
                                    );

                                    return Some(Extension::PerMessageDeflate(resp));
                                }
                            }
                            None
                        })
                    }
                };

                // Make the accepted sub-protocol available to later stages (request
                // extensions) and turn it into the response header.
                let protocols_header = match accepted_protocol {
                    Some(p) => {
                        tracing::trace!("inject accepted ws protocol in cfg: {p:?}");
                        req.extensions_mut().insert(p.clone());
                        Some(p.into_header())
                    }
                    None => None,
                };

                // Same for the accepted extension.
                let extensions_header = match accepted_extension {
                    Some(ext) => {
                        tracing::trace!("inject accepted ws extension in cfg: {ext:?}");
                        req.extensions_mut().insert(ext.clone());
                        Some(ext.into_header())
                    }
                    None => None,
                };

                match req.version() {
                    // HTTP/1.x: 101 with Sec-WebSocket-Accept, Upgrade and Connection
                    // headers. The response version mirrors the request version.
                    version @ (Version::HTTP_10 | Version::HTTP_11) => {
                        let accept_header = request_data.accept_header.ok_or_else(|| {
                            tracing::debug!("WebSocketAcceptor: missing accept header (no key?)");
                            StatusCode::BAD_REQUEST.into_response()
                        })?;

                        let mut response = (
                            StatusCode::SWITCHING_PROTOCOLS,
                            response::Headers((
                                accept_header,
                                headers::Upgrade::websocket(),
                                headers::Connection::upgrade(),
                            )),
                        )
                            .into_response();
                        *response.version_mut() = version;
                        if let Some(protocols) = protocols_header {
                            response.headers_mut().typed_insert(protocols);
                        }
                        if let Some(extensions) = extensions_header {
                            response.headers_mut().typed_insert(extensions);
                        }
                        Ok((response, req))
                    }
                    // HTTP/2 (extended CONNECT): a plain 200 OK.
                    Version::HTTP_2 => {
                        let mut response = StatusCode::OK.into_response();
                        *response.version_mut() = Version::HTTP_2;
                        if let Some(protocols) = protocols_header {
                            response.headers_mut().typed_insert(protocols);
                        }
                        if let Some(extensions) = extensions_header {
                            response.headers_mut().typed_insert(extensions);
                        }
                        Ok((response, req))
                    }
                    // Defensive: validation already rejects any other version.
                    version => {
                        tracing::debug!(
                            http.version = ?version,
                            "WebSocketAcceptor: http client request has unexpected http version"
                        );
                        Err(StatusCode::BAD_REQUEST.into_response())
                    }
                }
            }
            Err(err) => {
                // On a bad version header, advertise the supported version (13) in the
                // rejection.
                let response =
                    if matches!(err, RequestValidateError::InvalidSecWebSocketVersionHeader) {
                        (
                            Headers::single(headers::SecWebSocketVersion::V13),
                            StatusCode::BAD_REQUEST,
                        )
                            .into_response()
                    } else {
                        StatusCode::BAD_REQUEST.into_response()
                    };
                tracing::debug!("WebSocketAcceptor: http client request failed to validate: {err}");
                Err(response)
            }
        }
    }
}
659
/// A [`Service`] that accepts WebSocket handshakes and serves the upgraded connections.
///
/// Handshakes are accepted via its [`WebSocketAcceptor`], and each upgraded connection is
/// served as a [`ServerWebSocket`] by the inner service `S`. It is created via
/// [`WebSocketAcceptor::into_service`] or [`WebSocketAcceptor::into_echo_service`].
#[derive(Debug, Clone)]
pub struct WebSocketAcceptorService<S> {
    /// Validates handshake requests and builds the handshake responses.
    acceptor: WebSocketAcceptor,
    /// Optional WebSocket configuration, set via the generated `config` setters.
    config: Option<WebSocketConfig>,
    /// Handles each upgraded WebSocket; cloned once per accepted connection.
    service: S,
}
670
impl<S> WebSocketAcceptorService<S> {
    rama_utils::macros::generate_set_and_with! {
        /// Sets (or clears with `None`) the optional [`WebSocketConfig`] of this service.
        pub fn config(mut self, cfg: Option<WebSocketConfig>) -> Self {
            self.config = cfg;
            self
        }
    }
}
680
681#[derive(Debug)]
682pub struct ServerWebSocket {
688 socket: AsyncWebSocket,
689 request: request::Parts,
690}
691
692impl Deref for ServerWebSocket {
693 type Target = AsyncWebSocket;
694
695 fn deref(&self) -> &Self::Target {
696 &self.socket
697 }
698}
699
700impl DerefMut for ServerWebSocket {
701 fn deref_mut(&mut self) -> &mut Self::Target {
702 &mut self.socket
703 }
704}
705
706impl ServerWebSocket {
707 pub fn request(&self) -> &request::Parts {
709 &self.request
710 }
711
712 pub fn into_inner(self) -> AsyncWebSocket {
714 self.socket
715 }
716
717 pub fn into_parts(self) -> (AsyncWebSocket, request::Parts) {
719 (self.socket, self.request)
720 }
721}
722
723impl<S, Body> Service<Request<Body>> for WebSocketAcceptorService<S>
724where
725 S: Clone + Service<ServerWebSocket, Output = ()>,
726 Body: Send + 'static,
727{
728 type Output = Response;
729 type Error = S::Error;
730
731 async fn serve(&self, req: Request<Body>) -> Result<Self::Output, Self::Error> {
732 match self.acceptor.serve(req).await {
733 Ok((resp, req)) => {
734 #[cfg(not(feature = "compression"))]
735 if let Some(Extension::PerMessageDeflate(_)) = req.extensions().get() {
736 tracing::error!(
737 "per-message-deflate is used but compression feature is disabled. Enable it if you wish to use this extension."
738 );
739 return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
740 }
741
742 let handler = self.service.clone();
743 let span = tracing::trace_root_span!(
744 "ws::serve",
745 otel.kind = "server",
746 url.full = %req.uri(),
747 url.path = %req.uri().path(),
748 url.query = req.uri().query().unwrap_or_default(),
749 url.scheme = %req.uri().scheme().map(|s| s.as_str()).unwrap_or_default(),
750 network.protocol.name = "ws",
751 );
752
753 let exec = req
754 .extensions()
755 .get::<Executor>()
756 .cloned()
757 .unwrap_or_default();
758
759 exec.spawn_task(
760 async move {
761 match upgrade::handle_upgrade(&req).await {
762 Ok(upgraded) => {
763 #[cfg(feature = "compression")]
764 let maybe_ws_config = {
765 let mut ws_cfg = None;
766
767 tracing::trace!("check if pmd settings have to be applied to WS cfg...");
768
769 if let Some(Extension::PerMessageDeflate(pmd_cfg)) = req.extensions().get() {
770 tracing::trace!(
771 "apply accepted per-message-deflate cfg into WS server config: {pmd_cfg:?}"
772 );
773 ws_cfg = Some(WebSocketConfig {
774 per_message_deflate: Some(pmd_cfg.into()),
775 ..Default::default()
776 });
777 }
778
779 ws_cfg
780 };
781
782 #[cfg(not(feature = "compression"))]
783 let maybe_ws_config = None;
784
785 let socket =
786 AsyncWebSocket::from_raw_socket(upgraded, Role::Server, maybe_ws_config)
787 .await;
788
789 let (parts, _) = req.into_parts();
790
791 let server_socket = ServerWebSocket {
792 socket,
793 request: parts,
794 };
795
796 let _ = handler.serve( server_socket).await;
797 }
798 Err(e) => {
799 tracing::error!("ws upgrade error: {e:?}");
800 }
801 }
802 }
803 .instrument(span),
804 );
805 Ok(resp)
806 }
807 Err(resp) => Ok(resp),
808 }
809 }
810}
811
812const ECHO_SERVICE_SUB_PROTOCOL_DEFAULT_STR: &str = "echo";
813
814pub const ECHO_SERVICE_SUB_PROTOCOL_DEFAULT: NonEmptyStr =
816 non_empty_str!(ECHO_SERVICE_SUB_PROTOCOL_DEFAULT_STR);
817pub const ECHO_SERVICE_SUB_PROTOCOL_UPPER: NonEmptyStr = non_empty_str!("echo-upper");
819pub const ECHO_SERVICE_SUB_PROTOCOL_LOWER: NonEmptyStr = non_empty_str!("echo-lower");
821
/// A [`Service`] that echoes WebSocket messages back to the peer.
///
/// It can serve an [`AsyncWebSocket`], a [`ServerWebSocket`] or a raw upgraded connection.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct WebSocketEchoService;

impl WebSocketEchoService {
    /// Creates a new [`WebSocketEchoService`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}
834
835impl Service<AsyncWebSocket> for WebSocketEchoService {
836 type Output = ();
837 type Error = OpaqueError;
838
839 async fn serve(&self, mut socket: AsyncWebSocket) -> Result<Self::Output, Self::Error> {
840 let protocol = socket
841 .extensions()
842 .get::<headers::sec_websocket_protocol::AcceptedWebSocketProtocol>()
843 .map(|p| p.0.as_ref())
844 .unwrap_or(ECHO_SERVICE_SUB_PROTOCOL_DEFAULT_STR);
845
846 let transformer = if protocol.eq_ignore_ascii_case(&ECHO_SERVICE_SUB_PROTOCOL_LOWER) {
847 |msg: Message| match msg {
848 Message::Text(original) => Some(original.to_lowercase().into()),
849 msg @ Message::Binary(_) => Some(msg),
850 Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
851 }
852 } else if protocol.eq_ignore_ascii_case(&ECHO_SERVICE_SUB_PROTOCOL_UPPER) {
853 |msg: Message| match msg {
854 Message::Text(original) => Some(original.to_uppercase().into()),
855 msg @ Message::Binary(_) => Some(msg),
856 Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
857 }
858 } else {
859 |msg: Message| match msg {
860 msg @ (Message::Text(_) | Message::Binary(_)) => Some(msg),
861 Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
862 }
863 };
864
865 loop {
866 let msg = socket.recv_message().await.context("recv next msg")?;
867 if let Some(msg2) = transformer(msg) {
868 socket.send_message(msg2).await.context("echo msg back")?;
869 }
870 }
871 }
872}
873
874impl Service<ServerWebSocket> for WebSocketEchoService {
875 type Output = ();
876 type Error = OpaqueError;
877
878 async fn serve(&self, socket: ServerWebSocket) -> Result<Self::Output, Self::Error> {
879 let socket = socket.into_inner();
880 self.serve(socket).await
881 }
882}
883
884impl Service<upgrade::Upgraded> for WebSocketEchoService {
885 type Output = ();
886 type Error = OpaqueError;
887
888 async fn serve(&self, io: upgrade::Upgraded) -> Result<Self::Output, Self::Error> {
889 #[cfg(not(feature = "compression"))]
890 let maybe_ws_config = {
891 if let Some(Extension::PerMessageDeflate(_)) = io.extensions().get() {
892 return Err(OpaqueError::from_display(
893 "per-message-deflate is used but compression feature is disabled. Enable it if you wish to use this extension.",
894 ));
895 }
896 None
897 };
898
899 #[cfg(feature = "compression")]
900 let maybe_ws_config = {
901 let mut ws_cfg = None;
902
903 tracing::debug!("check if pmd settings have to be applied to WS cfg...");
904
905 if let Some(Extension::PerMessageDeflate(pmd_cfg)) = io.extensions().get() {
906 tracing::debug!(
907 "apply accepted per-message-deflate cfg into WS server config: {pmd_cfg:?}"
908 );
909 ws_cfg = Some(WebSocketConfig {
910 per_message_deflate: Some(pmd_cfg.into()),
911 ..Default::default()
912 });
913 }
914
915 ws_cfg
916 };
917
918 let socket = AsyncWebSocket::from_raw_socket(io, Role::Server, maybe_ws_config).await;
919 self.serve(socket).await
920 }
921}
922
923#[cfg(test)]
924mod tests {
925 use headers::sec_websocket_protocol::AcceptedWebSocketProtocol;
926 use rama_http::Body;
927 use rama_utils::str::non_empty_str;
928
929 use super::*;
930
931 macro_rules! request {
932 (
933 $method:literal $version:literal $uri:literal
934 $(
935 $header_name:literal: $header_value:literal
936 )*
937 ) => {
938 request!(
939 $method $version $uri
940 $(
941 $header_name: $header_value
942 )*
943 w/ []
944 )
945 };
946 (
947 $method:literal $version:literal $uri:literal
948 $(
949 $header_name:literal: $header_value:literal
950 )*
951 w/ [$($extension:expr),* $(,)?]
952 ) => {
953 {
954 let req = Request::builder()
955 .uri($uri)
956 .version(match $version {
957 "HTTP/1.1" => Version::HTTP_11,
958 "HTTP/2" => Version::HTTP_2,
959 _ => unreachable!(),
960 })
961 .method(match $method {
962 "GET" => Method::GET,
963 "POST" => Method::POST,
964 "CONNECT" => Method::CONNECT,
965 _ => unreachable!(),
966 });
967
968 $(
969 let req = req.header($header_name, $header_value);
970 )*
971
972 $(
973 let req = req.extension($extension);
974 )*
975
976 req.body(Body::empty()).unwrap()
977 }
978 };
979 }
980
981 fn assert_websocket_no_match(request: &Request, matcher: &WebSocketMatcher) {
982 assert!(
983 !matcher.matches(None, request),
984 "!({matcher:?}).matches({request:?})"
985 );
986 }
987
988 fn assert_websocket_match(request: &Request, matcher: &WebSocketMatcher) {
989 assert!(
990 matcher.matches(None, request),
991 "({matcher:?}).matches({request:?})"
992 );
993 }
994
995 #[test]
996 fn test_websocket_match_default_http_11() {
997 let matcher = WebSocketMatcher::default();
998
999 assert_websocket_no_match(
1000 &request! {
1001 "GET" "HTTP/1.1" "/"
1002 },
1003 &matcher,
1004 );
1005 assert_websocket_no_match(
1006 &request! {
1007 "GET" "HTTP/1.1" "/"
1008 "Upgrade": "websocket"
1009 },
1010 &matcher,
1011 );
1012 assert_websocket_no_match(
1013 &request! {
1014 "GET" "HTTP/1.1" "/"
1015 "Connection": "upgrade"
1016 },
1017 &matcher,
1018 );
1019 assert_websocket_match(
1020 &request! {
1021 "GET" "HTTP/1.1" "/"
1022 "Connection": "upgrade"
1023 "Upgrade": "websocket"
1024 },
1025 &matcher,
1026 );
1027 }
1028
1029 #[test]
1030 fn test_websocket_match_default_http_2() {
1031 let matcher = WebSocketMatcher::default();
1032
1033 assert_websocket_no_match(
1034 &request! {
1035 "GET" "HTTP/2" "/"
1036 "Connection": "upgrade"
1037 "Upgrade": "websocket"
1038 "Sec-WebSocket-Version": "13"
1039 "Sec-WebSocket-Key": "foobar"
1040 },
1041 &matcher,
1042 );
1043 assert_websocket_match(
1044 &request! {
1045 "CONNECT" "HTTP/2" "/"
1046 w/ [
1047 Protocol::from_static("websocket"),
1048 ]
1049 },
1050 &matcher,
1051 );
1052 assert_websocket_no_match(
1053 &request! {
1054 "GET" "HTTP/2" "/"
1055 w/ [
1056 Protocol::from_static("websocket"),
1057 ]
1058 },
1059 &matcher,
1060 );
1061 }
1062
1063 async fn assert_websocket_acceptor_ok(
1064 request: Request,
1065 acceptor: &WebSocketAcceptor,
1066 expected_accepted_protocol: Option<AcceptedWebSocketProtocol>,
1067 ) {
1068 let (resp, req) = acceptor.serve(request).await.unwrap();
1069 match req.version() {
1070 Version::HTTP_10 | Version::HTTP_11 => {
1071 assert_eq!(StatusCode::SWITCHING_PROTOCOLS, resp.status())
1072 }
1073 Version::HTTP_2 => assert_eq!(StatusCode::OK, resp.status()),
1074 _ => unreachable!(),
1075 }
1076 let accepted_protocol = resp
1077 .headers()
1078 .typed_get::<headers::SecWebSocketProtocol>()
1079 .map(|p| p.accept_first_protocol());
1080 if let Some(expected_accepted_protocol) = expected_accepted_protocol {
1081 assert_eq!(
1082 accepted_protocol.as_ref(),
1083 Some(&expected_accepted_protocol),
1084 "request = {req:?}"
1085 );
1086 assert_eq!(
1087 req.extensions().get::<AcceptedWebSocketProtocol>(),
1088 Some(&expected_accepted_protocol),
1089 "request = {req:?}"
1090 );
1091 } else {
1092 assert!(accepted_protocol.is_none());
1093 assert!(
1094 req.extensions()
1095 .get::<AcceptedWebSocketProtocol>()
1096 .is_none()
1097 );
1098 }
1099 }
1100
1101 async fn assert_websocket_acceptor_bad_request(request: Request, acceptor: &WebSocketAcceptor) {
1102 let resp = acceptor.serve(request).await.unwrap_err();
1103 assert_eq!(StatusCode::BAD_REQUEST, resp.status());
1104 }
1105
1106 #[tokio::test]
1107 async fn test_websocket_acceptor_default_http_2() {
1108 let acceptor = WebSocketAcceptor::default();
1109
1110 assert_websocket_acceptor_bad_request(
1111 request! {
1112 "GET" "HTTP/2" "/"
1113 "Connection": "upgrade"
1114 "Upgrade": "websocket"
1115 "Sec-WebSocket-Version": "13"
1116 "Sec-WebSocket-Key": "foobar"
1117 },
1118 &acceptor,
1119 )
1120 .await;
1121 assert_websocket_acceptor_bad_request(
1122 request! {
1123 "CONNECT" "HTTP/2" "/"
1124 w/ [
1125 Protocol::from_static("websocket"),
1126 ]
1127 },
1128 &acceptor,
1129 )
1130 .await;
1131 assert_websocket_acceptor_bad_request(
1132 request! {
1133 "GET" "HTTP/2" "/"
1134 w/ [
1135 Protocol::from_static("websocket"),
1136 ]
1137 },
1138 &acceptor,
1139 )
1140 .await;
1141
1142 assert_websocket_acceptor_ok(
1143 request! {
1144 "CONNECT" "HTTP/2" "/"
1145 "Sec-WebSocket-Version": "13"
1146 w/ [
1147 Protocol::from_static("websocket"),
1148 ]
1149 },
1150 &acceptor,
1151 None,
1152 )
1153 .await;
1154
1155 assert_websocket_acceptor_bad_request(
1156 request! {
1157 "CONNECT" "HTTP/2" "/"
1158 "Sec-WebSocket-Version": "13"
1159 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1160 "Sec-WebSocket-Protocol": "client"
1161 w/ [
1162 Protocol::from_static("websocket"),
1163 ]
1164 },
1165 &acceptor,
1166 )
1167 .await;
1168 }
1169
1170 #[tokio::test]
1171 async fn test_websocket_acceptor_default_http_11() {
1172 let acceptor = WebSocketAcceptor::default();
1173
1174 assert_websocket_acceptor_bad_request(
1175 request! {
1176 "GET" "HTTP/1.1" "/"
1177 "Connection": "upgrade"
1178 "Upgrade": "websocket"
1179 "Sec-WebSocket-Version": "13"
1180 "Sec-WebSocket-Key": "foobar"
1181 },
1182 &acceptor,
1183 )
1184 .await;
1185
1186 assert_websocket_acceptor_bad_request(
1187 request! {
1188 "GET" "HTTP/1.1" "/"
1189 "Connection": "upgrade"
1190 "Upgrade": "websocket"
1191 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1192 },
1193 &acceptor,
1194 )
1195 .await;
1196
1197 assert_websocket_acceptor_bad_request(
1198 request! {
1199 "GET" "HTTP/1.1" "/"
1200 "Connection": "upgrade"
1201 "Upgrade": "websocket"
1202 "Sec-WebSocket-Version": "14"
1203 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1204 },
1205 &acceptor,
1206 )
1207 .await;
1208
1209 assert_websocket_acceptor_bad_request(
1210 request! {
1211 "GET" "HTTP/1.1" "/"
1212 "Connection": "upgrade"
1213 "Upgrade": "foo"
1214 "Sec-WebSocket-Version": "13"
1215 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1216 },
1217 &acceptor,
1218 )
1219 .await;
1220
1221 assert_websocket_acceptor_bad_request(
1222 request! {
1223 "GET" "HTTP/1.1" "/"
1224 "Connection": "upgrade"
1225 "Sec-WebSocket-Version": "13"
1226 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1227 },
1228 &acceptor,
1229 )
1230 .await;
1231
1232 assert_websocket_acceptor_bad_request(
1233 request! {
1234 "GET" "HTTP/1.1" "/"
1235 "Upgrade": "websocket"
1236 "Sec-WebSocket-Version": "13"
1237 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1238 },
1239 &acceptor,
1240 )
1241 .await;
1242
1243 assert_websocket_acceptor_bad_request(
1244 request! {
1245 "GET" "HTTP/1.1" "/"
1246 "Connection": "keep-alive"
1247 "Upgrade": "websocket"
1248 "Sec-WebSocket-Version": "13"
1249 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1250 },
1251 &acceptor,
1252 )
1253 .await;
1254
1255 assert_websocket_acceptor_ok(
1256 request! {
1257 "GET" "HTTP/1.1" "/"
1258 "Connection": "upgrade"
1259 "Upgrade": "websocket"
1260 "Sec-WebSocket-Version": "13"
1261 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1262 },
1263 &acceptor,
1264 None,
1265 )
1266 .await;
1267 }
1268
1269 #[tokio::test]
1270 async fn test_websocket_accept_flex_protocols() {
1271 let acceptor = WebSocketAcceptor::default().with_protocols_flex(true);
1272
1273 assert_websocket_acceptor_ok(
1276 request! {
1277 "GET" "HTTP/1.1" "/"
1278 "Connection": "upgrade"
1279 "Upgrade": "websocket"
1280 "Sec-WebSocket-Version": "13"
1281 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1282 },
1283 &acceptor,
1284 None,
1285 )
1286 .await;
1287 assert_websocket_acceptor_ok(
1288 request! {
1289 "CONNECT" "HTTP/2" "/"
1290 "Sec-WebSocket-Version": "13"
1291 w/ [
1292 Protocol::from_static("websocket"),
1293 ]
1294 },
1295 &acceptor,
1296 None,
1297 )
1298 .await;
1299
1300 assert_websocket_acceptor_ok(
1303 request! {
1304 "GET" "HTTP/1.1" "/"
1305 "Connection": "upgrade"
1306 "Upgrade": "websocket"
1307 "Sec-WebSocket-Version": "13"
1308 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1309 "Sec-WebSocket-Protocol": "foo"
1310 },
1311 &acceptor,
1312 Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
1313 )
1314 .await;
1315 assert_websocket_acceptor_ok(
1316 request! {
1317 "CONNECT" "HTTP/2" "/"
1318 "Sec-WebSocket-Version": "13"
1319 "Sec-WebSocket-Protocol": "foo"
1320 w/ [
1321 Protocol::from_static("websocket"),
1322 ]
1323 },
1324 &acceptor,
1325 Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
1326 )
1327 .await;
1328
1329 assert_websocket_acceptor_ok(
1332 request! {
1333 "GET" "HTTP/1.1" "/"
1334 "Connection": "upgrade"
1335 "Upgrade": "websocket"
1336 "Sec-WebSocket-Version": "13"
1337 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1338 "Sec-WebSocket-Protocol": "foo, bar"
1339 },
1340 &acceptor,
1341 Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
1342 )
1343 .await;
1344 assert_websocket_acceptor_ok(
1345 request! {
1346 "CONNECT" "HTTP/2" "/"
1347 "Sec-WebSocket-Version": "13"
1348 "Sec-WebSocket-Protocol": "foo,baz, foo"
1349 w/ [
1350 Protocol::from_static("websocket"),
1351 ]
1352 },
1353 &acceptor,
1354 Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
1355 )
1356 .await;
1357
1358 let acceptor =
1362 acceptor.with_protocols(headers::SecWebSocketProtocol::new(non_empty_str!("foo")));
1363
1364 assert_websocket_acceptor_ok(
1365 request! {
1366 "GET" "HTTP/1.1" "/"
1367 "Connection": "upgrade"
1368 "Upgrade": "websocket"
1369 "Sec-WebSocket-Version": "13"
1370 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1371 },
1372 &acceptor,
1373 None,
1374 )
1375 .await;
1376
1377 assert_websocket_acceptor_bad_request(
1378 request! {
1379 "CONNECT" "HTTP/2" "/"
1380 "Sec-WebSocket-Version": "13"
1381 "Sec-WebSocket-Protocol": "baz,fo"
1382 w/ [
1383 Protocol::from_static("websocket"),
1384 ]
1385 },
1386 &acceptor,
1387 )
1388 .await;
1389 }
1390
1391 #[tokio::test]
1392 async fn test_websocket_accept_required_protocols() {
1393 let acceptor = WebSocketAcceptor::default().with_protocols(headers::SecWebSocketProtocol(
1394 non_empty_smallvec![
1395 non_empty_str!("foo"),
1396 non_empty_str!("a"),
1397 non_empty_str!("b")
1398 ],
1399 ));
1400
1401 assert_websocket_acceptor_bad_request(
1404 request! {
1405 "GET" "HTTP/1.1" "/"
1406 "Connection": "upgrade"
1407 "Upgrade": "websocket"
1408 "Sec-WebSocket-Version": "13"
1409 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1410 },
1411 &acceptor,
1412 )
1413 .await;
1414 assert_websocket_acceptor_bad_request(
1415 request! {
1416 "CONNECT" "HTTP/2" "/"
1417 "Sec-WebSocket-Version": "13"
1418 w/ [
1419 Protocol::from_static("websocket"),
1420 ]
1421 },
1422 &acceptor,
1423 )
1424 .await;
1425
1426 assert_websocket_acceptor_ok(
1429 request! {
1430 "GET" "HTTP/1.1" "/"
1431 "Connection": "upgrade"
1432 "Upgrade": "websocket"
1433 "Sec-WebSocket-Version": "13"
1434 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1435 "Sec-WebSocket-Protocol": "foo"
1436 },
1437 &acceptor,
1438 Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
1439 )
1440 .await;
1441 assert_websocket_acceptor_ok(
1442 request! {
1443 "CONNECT" "HTTP/2" "/"
1444 "Sec-WebSocket-Version": "13"
1445 "Sec-WebSocket-Protocol": "b"
1446 w/ [
1447 Protocol::from_static("websocket"),
1448 ]
1449 },
1450 &acceptor,
1451 Some(AcceptedWebSocketProtocol(non_empty_str!("b"))),
1452 )
1453 .await;
1454
1455 assert_websocket_acceptor_ok(
1458 request! {
1459 "GET" "HTTP/1.1" "/"
1460 "Connection": "upgrade"
1461 "Upgrade": "websocket"
1462 "Sec-WebSocket-Version": "13"
1463 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1464 "Sec-WebSocket-Protocol": "test, b"
1465 },
1466 &acceptor,
1467 Some(AcceptedWebSocketProtocol(non_empty_str!("b"))),
1468 )
1469 .await;
1470 assert_websocket_acceptor_ok(
1471 request! {
1472 "CONNECT" "HTTP/2" "/"
1473 "Sec-WebSocket-Version": "13"
1474 "Sec-WebSocket-Protocol": "a,test, c"
1475 w/ [
1476 Protocol::from_static("websocket"),
1477 ]
1478 },
1479 &acceptor,
1480 Some(AcceptedWebSocketProtocol(non_empty_str!("a"))),
1481 )
1482 .await;
1483
1484 assert_websocket_acceptor_bad_request(
1487 request! {
1488 "GET" "HTTP/1.1" "/"
1489 "Connection": "upgrade"
1490 "Upgrade": "websocket"
1491 "Sec-WebSocket-Version": "13"
1492 "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
1493 "Sec-WebSocket-Protocol": "test, c"
1494 },
1495 &acceptor,
1496 )
1497 .await;
1498 assert_websocket_acceptor_bad_request(
1499 request! {
1500 "CONNECT" "HTTP/2" "/"
1501 "Sec-WebSocket-Version": "13"
1502 "Sec-WebSocket-Protocol": "test"
1503 w/ [
1504 Protocol::from_static("websocket"),
1505 ]
1506 },
1507 &acceptor,
1508 )
1509 .await;
1510 }
1511}