rama_ws/handshake/
server.rs

1//! WebSocket server types and utilities
2
3use std::{
4    fmt,
5    ops::{Deref, DerefMut},
6};
7
8use rama_core::{
9    Service,
10    error::{ErrorContext, OpaqueError},
11    extensions::{Extensions, ExtensionsMut, ExtensionsRef},
12    matcher::Matcher,
13    rt::Executor,
14    telemetry::tracing::{self, Instrument},
15};
16#[cfg(feature = "compression")]
17use rama_http::headers::sec_websocket_extensions;
18use rama_http::{
19    Method, Request, Response, StatusCode, Version,
20    headers::{
21        self, HeaderMapExt,
22        sec_websocket_extensions::{Extension, PerMessageDeflateConfig},
23    },
24    io::upgrade,
25    proto::h2::ext::Protocol,
26    request,
27    service::web::response::{self, Headers, IntoResponse},
28};
29use rama_utils::{
30    collections::non_empty_smallvec,
31    str::{NonEmptyStr, non_empty_str},
32};
33
34use crate::{
35    Message,
36    protocol::{Role, WebSocketConfig},
37    runtime::AsyncWebSocket,
38};
39
/// WebSocket [`Matcher`] to match on incoming WebSocket requests.
///
/// Even the [`Default`] matcher already performs the basic checks:
///
/// - for http/1.1: require GET method and `Upgrade: websocket` + `Connection: upgrade` headers
/// - for h2: require CONNECT method and `:protocol: websocket` pseudo header
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct WebSocketMatcher;

impl WebSocketMatcher {
    /// Create a new default [`WebSocketMatcher`].
    ///
    /// Equivalent to [`WebSocketMatcher::default`].
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self
    }
}
58
59impl<Body> Matcher<Request<Body>> for WebSocketMatcher
60where
61    Body: Send + 'static,
62{
63    fn matches(&self, _ext: Option<&mut Extensions>, req: &Request<Body>) -> bool {
64        match req.version() {
65            version @ (Version::HTTP_10 | Version::HTTP_11) => {
66                match req.method() {
67                    &Method::GET => (),
68                    method => {
69                        tracing::debug!(
70                            http.version = ?version,
71                            http.request.method = %method,
72                            "WebSocketMatcher: h1: unexpected method found: no match",
73                        );
74                        return false;
75                    }
76                }
77
78                if !req
79                    .headers()
80                    .typed_get::<headers::Upgrade>()
81                    .map(|u| u.is_websocket())
82                    .unwrap_or_default()
83                {
84                    tracing::trace!(
85                        http.version = ?version,
86                        "WebSocketMatcher: h1: no websocket upgrade header found: no match"
87                    );
88                    return false;
89                }
90
91                if !req
92                    .headers()
93                    .typed_get::<headers::Connection>()
94                    .map(|c| c.contains_upgrade())
95                    .unwrap_or_default()
96                {
97                    tracing::trace!(
98                        http.version = ?version,
99                        "WebSocketMatcher: h1: no connection upgrade header found: no match",
100                    );
101                    return false;
102                }
103            }
104            version @ Version::HTTP_2 => {
105                match req.method() {
106                    &Method::CONNECT => (),
107                    method => {
108                        tracing::debug!(
109                            http.version = ?version,
110                            http.request.method = %method,
111                            "WebSocketMatcher: h2: unexpected method found: no match",
112                        );
113                        return false;
114                    }
115                }
116
117                if !req
118                    .extensions()
119                    .get::<Protocol>()
120                    .map(|p| p.as_str().trim().eq_ignore_ascii_case("websocket"))
121                    .unwrap_or_default()
122                {
123                    tracing::trace!(
124                        http.version = ?version,
125                        "WebSocketMatcher: h2: no websocket protocol (pseudo ext) found",
126                    );
127                    return false;
128                }
129            }
130            version => {
131                tracing::debug!(
132                    http.version = ?version,
133                    "WebSocketMatcher: unexpected http version found: no match",
134                );
135                return false;
136            }
137        }
138
139        true
140    }
141}
142
143#[derive(Debug)]
144/// Server error which can be triggered in case the request validation failed
145pub enum RequestValidateError {
146    UnexpectedHttpMethod(Method),
147    UnexpectedHttpVersion(Version),
148    UnexpectedPseudoProtocolHeader(Option<Protocol>),
149    MissingUpgradeWebSocketHeader,
150    MissingConnectionUpgradeHeader,
151    InvalidSecWebSocketVersionHeader,
152    InvalidSecWebSocketKeyHeader,
153    InvalidSecWebSocketProtocolHeader(OpaqueError),
154}
155
156impl fmt::Display for RequestValidateError {
157    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
158        match self {
159            Self::UnexpectedHttpMethod(method) => {
160                write!(f, "unexpected HTTP method: {method:?}")
161            }
162            Self::UnexpectedHttpVersion(version) => {
163                write!(f, "unexpected HTTP version: {version:?}")
164            }
165            Self::UnexpectedPseudoProtocolHeader(maybe_protocol) => {
166                write!(
167                    f,
168                    "missing or invalid pseudo h2 protocol header: {maybe_protocol:?}"
169                )
170            }
171            Self::MissingUpgradeWebSocketHeader => {
172                write!(f, "missing upgrade WebSocket header")
173            }
174            Self::MissingConnectionUpgradeHeader => {
175                write!(f, "missing connection upgrade header")
176            }
177            Self::InvalidSecWebSocketVersionHeader => {
178                write!(f, "missing or invalid sec-websocket-version header")
179            }
180            Self::InvalidSecWebSocketKeyHeader => {
181                write!(f, "missing or invalid sec-websocket-key header")
182            }
183            Self::InvalidSecWebSocketProtocolHeader(err) => {
184                write!(f, "invalid sec-websocket-protocol header: {err}")
185            }
186        }
187    }
188}
189
190impl std::error::Error for RequestValidateError {}
191
192#[derive(Debug)]
193pub struct ClientRequestData {
194    pub accept_header: Option<headers::SecWebSocketAccept>,
195    pub protocol: Option<headers::SecWebSocketProtocol>,
196    pub extensions: Option<headers::SecWebSocketExtensions>,
197}
198
199pub fn validate_http_client_request<Body>(
200    request: &Request<Body>,
201) -> Result<ClientRequestData, RequestValidateError> {
202    tracing::trace!(
203        http.version = ?request.version(),
204        "validate http client request"
205    );
206
207    let mut accept_header = None;
208
209    match request.version() {
210        Version::HTTP_10 | Version::HTTP_11 => {
211            match request.method() {
212                &Method::GET => (),
213                method => return Err(RequestValidateError::UnexpectedHttpMethod(method.clone())),
214            }
215
216            // If the request lacks an |Upgrade| header field or the |Upgrade|
217            // header field contains a value that is not an ASCII case-
218            // insensitive match for the value "websocket", the server MUST
219            // _Fail the WebSocket Connection_. (RFC 6455)
220            if !request
221                .headers()
222                .typed_get::<headers::Upgrade>()
223                .map(|u| u.is_websocket())
224                .unwrap_or_default()
225            {
226                return Err(RequestValidateError::MissingUpgradeWebSocketHeader);
227            }
228
229            // If the request lacks a |Connection| header field or the
230            // |Connection| header field doesn't contain a token that is an
231            // ASCII case-insensitive match for the value "Upgrade", the server
232            // MUST _Fail the WebSocket Connection_. (RFC 6455)
233            if !request
234                .headers()
235                .typed_get::<headers::Connection>()
236                .map(|c| c.contains_upgrade())
237                .unwrap_or_default()
238            {
239                return Err(RequestValidateError::MissingConnectionUpgradeHeader);
240            }
241
242            // A |Sec-WebSocket-Key| header field with a base64-encoded (see
243            // Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
244            // length.
245            //
246            // Only used for http/1.1 style WebSocket upgrade, not h2
247            // as in the latter it is deprecated by the `:protocol` PSEUDO header.
248            accept_header = match request.headers().typed_get::<headers::SecWebSocketKey>() {
249                Some(key) => headers::SecWebSocketAccept::try_from(key)
250                    .inspect_err(|err| {
251                        tracing::debug!(
252                            "failed to create accept typed header from given key: {err}"
253                        )
254                    })
255                    .ok(),
256                None => return Err(RequestValidateError::InvalidSecWebSocketKeyHeader),
257            };
258        }
259        Version::HTTP_2 => {
260            match request.method() {
261                &Method::CONNECT => (),
262                method => return Err(RequestValidateError::UnexpectedHttpMethod(method.clone())),
263            }
264
265            match request.extensions().get::<Protocol>() {
266                None => return Err(RequestValidateError::UnexpectedPseudoProtocolHeader(None)),
267                Some(protocol) => {
268                    if !protocol.as_str().trim().eq_ignore_ascii_case("websocket") {
269                        return Err(RequestValidateError::UnexpectedPseudoProtocolHeader(Some(
270                            protocol.clone(),
271                        )));
272                    }
273                }
274            }
275        }
276        version => {
277            return Err(RequestValidateError::UnexpectedHttpVersion(version));
278        }
279    }
280
281    // A |Sec-WebSocket-Version| header field, with a value of 13.
282    if request
283        .headers()
284        .typed_get::<headers::SecWebSocketVersion>()
285        .is_none()
286    {
287        return Err(RequestValidateError::InvalidSecWebSocketVersionHeader);
288    }
289
290    // Optionally, a |Sec-WebSocket-Protocol| header field, with a list
291    // of values indicating which protocols the client would like to
292    // speak, ordered by preference.
293    let protocols_header = request.headers().typed_get();
294
295    // Also optionally, a |Sec-WebSocket-Extensions| header field, with a list
296    // of values indicating which extensions the client would like to
297    // utilise, ordered by preference.
298    let extensions_header = request.headers().typed_get();
299
300    Ok(ClientRequestData {
301        accept_header,
302        protocol: protocols_header,
303        extensions: extensions_header,
304    })
305}
306
307#[derive(Debug, Clone, Default)]
308/// An acceptor that can be used for upgrades os WebSockets on the server side.
309pub struct WebSocketAcceptor {
310    protocols: Option<headers::SecWebSocketProtocol>,
311    protocols_flex: bool,
312
313    // extensions are always flexible in context of what both
314    // client and server support... as such... extensions *_*
315    extensions: Option<headers::SecWebSocketExtensions>,
316}
317
318impl WebSocketAcceptor {
319    #[inline]
320    /// Create a new default [`WebSocketAcceptor`].
321    #[must_use]
322    pub fn new() -> Self {
323        Default::default()
324    }
325
326    rama_utils::macros::generate_set_and_with! {
327        /// Define if the protocols validation and actioning is flexible.
328        ///
329        /// - In case no protocols are defined by server it implies that
330        ///   the server will accept any incoming protocol instead of denying protocols.
331        /// - Or in case server did specify a protocol allow list it will also
332        ///   accept incoming requests which do not define a protocol.
333        pub fn protocols_flex(mut self, flexible: bool) -> Self {
334            self.protocols_flex = flexible;
335            self
336        }
337    }
338
339    rama_utils::macros::generate_set_and_with! {
340        /// Define the WebSocket protocols.
341        ///
342        /// The protocols defined by the server (matcher) act as an allow list.
343        /// You can make protocols optional in case you also wish to allow no
344        /// protocols to be defined by marking protocols as flexible.
345        pub fn protocols(mut self, protocols: Option<headers::SecWebSocketProtocol>) -> Self {
346            self.protocols = protocols;
347            self
348        }
349    }
350
351    rama_utils::macros::generate_set_and_with! {
352        /// Define the WebSocket rama echo protocols.
353        pub fn echo_protocols(mut self) -> Self {
354            self.protocols = Some(headers::SecWebSocketProtocol(non_empty_smallvec![
355                ECHO_SERVICE_SUB_PROTOCOL_DEFAULT,
356                    ECHO_SERVICE_SUB_PROTOCOL_UPPER,
357                    ECHO_SERVICE_SUB_PROTOCOL_LOWER,
358                ]));
359            self
360        }
361    }
362
363    rama_utils::macros::generate_set_and_with! {
364        /// Define the WebSocket extensions to be supported by the server.
365        pub fn extensions(mut self, extensions: Option<headers::SecWebSocketExtensions>) -> Self {
366            self.extensions = extensions;
367            self
368        }
369    }
370
371    #[cfg(feature = "compression")]
372    rama_utils::macros::generate_set_and_with! {
373        /// Set or add the deflate WebSocket extension with the default config
374        #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
375        pub fn per_message_deflate(mut self) -> Self {
376            self.extensions = match self.extensions.take() {
377                Some(ext) => {
378                    Some(ext.with_extra_extension(Extension::PerMessageDeflate(Default::default())))
379                },
380                None => Some(headers::SecWebSocketExtensions::per_message_deflate()),
381            };
382            self
383        }
384    }
385
386    #[cfg(feature = "compression")]
387    rama_utils::macros::generate_set_and_with! {
388        /// Set the deflate WebSocket extension with the default config,
389        /// erasing existing if it already exists.
390        #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
391        pub fn per_message_deflate_overwrite_extensions(mut self) -> Self {
392            self.extensions = Some(headers::SecWebSocketExtensions::per_message_deflate());
393            self
394        }
395    }
396
397    #[cfg(feature = "compression")]
398    rama_utils::macros::generate_set_and_with! {
399        /// Set or add the deflate WebSocket extension with the given config,
400        /// erasing existing if it already exists.
401        #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
402        pub fn per_message_deflate_with_config(mut self, config: impl Into<sec_websocket_extensions::PerMessageDeflateConfig>) -> Self {
403            self.extensions = match self.extensions.take() {
404                Some(ext) => {
405                    Some(ext.with_extra_extension(Extension::PerMessageDeflate(config.into())))
406                },
407                None => Some(headers::SecWebSocketExtensions::per_message_deflate_with_config(config.into())),
408            };
409            self
410        }
411    }
412
413    #[cfg(feature = "compression")]
414    rama_utils::macros::generate_set_and_with! {
415        /// Set or add the deflate WebSocket extension with the given config,
416        /// erasing existing if it already exists.
417        #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
418        pub fn per_message_deflate_with_config_overwrite_extensions(mut self, config: impl Into<sec_websocket_extensions::PerMessageDeflateConfig>) -> Self {
419            self.extensions = Some(headers::SecWebSocketExtensions::per_message_deflate_with_config(config.into()));
420            self
421        }
422    }
423}
424
425impl WebSocketAcceptor {
426    /// Consume `self` into an [`WebSocketAcceptorService`] ready to serve.
427    ///
428    /// Use the `UpgradeLayer` in case the ws upgrade is optional.
429    pub fn into_service<S>(self, service: S) -> WebSocketAcceptorService<S> {
430        WebSocketAcceptorService {
431            acceptor: self,
432            config: None,
433            service,
434        }
435    }
436
437    /// Turn this [`WebSocketAcceptor`] into an echo [`WebSocketAcceptorService`]].
438    #[must_use]
439    pub fn into_echo_service(mut self) -> WebSocketAcceptorService<WebSocketEchoService> {
440        if self.protocols.is_none() {
441            self.protocols_flex = true;
442            self.protocols = Some(headers::SecWebSocketProtocol(non_empty_smallvec![
443                ECHO_SERVICE_SUB_PROTOCOL_DEFAULT,
444                ECHO_SERVICE_SUB_PROTOCOL_UPPER,
445                ECHO_SERVICE_SUB_PROTOCOL_LOWER,
446            ]));
447        }
448
449        WebSocketAcceptorService {
450            acceptor: self,
451            config: None,
452            service: WebSocketEchoService::new(),
453        }
454    }
455}
456
457impl<Body> Service<Request<Body>> for WebSocketAcceptor
458where
459    Body: Send + 'static,
460{
461    type Output = (Response, Request<Body>);
462    type Error = Response;
463
464    async fn serve(&self, mut req: Request<Body>) -> Result<Self::Output, Self::Error> {
465        match validate_http_client_request(&req) {
466            Ok(request_data) => {
467                let accepted_protocol = match (
468                    self.protocols_flex,
469                    request_data.protocol,
470                    self.protocols.as_ref(),
471                ) {
472                    (false, Some(protocols), None) => {
473                        tracing::debug!(
474                            "WebSocketAcceptor: protocols found while none were expected: {protocols:?}"
475                        );
476                        return Err(StatusCode::BAD_REQUEST.into_response());
477                    }
478                    (false, None, Some(protocols)) => {
479                        tracing::debug!(
480                            "WebSocketAcceptor: no protocols found while one of following was expected: {protocols:?}"
481                        );
482                        return Err(StatusCode::BAD_REQUEST.into_response());
483                    }
484                    (_, None, None) | (true, None, Some(_)) => None,
485                    (true, Some(found_protocols), None) => {
486                        Some(found_protocols.accept_first_protocol())
487                    }
488                    (_, Some(found_protocols), Some(expected_protocols)) => {
489                        if let Some(protocol) =
490                            found_protocols.contains_any(expected_protocols.iter())
491                        {
492                            Some(protocol)
493                        } else {
494                            tracing::debug!(
495                                "WebSocketAcceptor: no protocols from found protocol ({found_protocols:?}) matched for expected protocols: {expected_protocols:?}"
496                            );
497                            return Err(StatusCode::BAD_REQUEST.into_response());
498                        }
499                    }
500                };
501
502                let accepted_extension = match (request_data.extensions, self.extensions.as_ref()) {
503                    (None, _) | (_, None) => None,
504                    (Some(request_extensions), Some(allowed_extensions)) => {
505                        request_extensions.0.iter().find_map(|request_ext| {
506                            for allowed_ext in allowed_extensions.0.iter() {
507                                if let (
508                                    Extension::PerMessageDeflate(request_pmd),
509                                    Extension::PerMessageDeflate(allowed_pmd),
510                                ) = (&request_ext, allowed_ext)
511                                {
512                                    let mut resp = PerMessageDeflateConfig {
513                                        identifier: allowed_pmd.identifier.clone(),
514                                        client_no_context_takeover: request_pmd
515                                            .client_no_context_takeover
516                                            && allowed_pmd.client_no_context_takeover,
517                                        server_no_context_takeover: allowed_pmd
518                                            .server_no_context_takeover,
519                                        ..Default::default()
520                                    };
521
522                                    // server_max_window_bits
523                                    // server may include this even if client did not offer it
524                                    let srv_cap = allowed_pmd.server_max_window_bits.unwrap_or(15);
525                                    let srv_cap = if srv_cap == 0 {
526                                        15
527                                    } else {
528                                        srv_cap.clamp(8, 15)
529                                    };
530                                    let cli_req_srv = request_pmd
531                                        .server_max_window_bits
532                                        .map(|v| if v == 0 { 15 } else { v.clamp(8, 15) });
533                                    let chosen_srv_bits = match (cli_req_srv, Some(srv_cap)) {
534                                        (Some(client_bits), Some(cap)) => {
535                                            Some(client_bits.min(cap))
536                                        }
537                                        (None, Some(cap)) => Some(cap),
538                                        _ => None,
539                                    };
540                                    // include only if it actually constrains or was explicitly discussed
541                                    resp.server_max_window_bits = match chosen_srv_bits {
542                                        Some(bits) if bits < 15 || cli_req_srv.is_some() => {
543                                            Some(bits)
544                                        }
545                                        _ => None,
546                                    };
547
548                                    // client_max_window_bits
549                                    // server must not include unless client offered it
550                                    resp.client_max_window_bits = request_pmd
551                                        .client_max_window_bits
552                                        .map(|client_bits_offer| {
553                                            let offer = if client_bits_offer == 0 {
554                                                15
555                                            } else {
556                                                client_bits_offer.clamp(8, 15)
557                                            };
558                                            let cap =
559                                                allowed_pmd.client_max_window_bits.unwrap_or(offer);
560                                            if cap == 0 {
561                                                offer
562                                            } else {
563                                                offer.min(cap.clamp(8, 15))
564                                            }
565                                        });
566
567                                    tracing::trace!(
568                                        "accept and use ws deflate ext w/ config: {resp:?}"
569                                    );
570
571                                    return Some(Extension::PerMessageDeflate(resp));
572                                }
573                            }
574                            None
575                        })
576                    }
577                };
578
579                let protocols_header = match accepted_protocol {
580                    Some(p) => {
581                        tracing::trace!("inject accepted ws protocol in cfg: {p:?}");
582                        req.extensions_mut().insert(p.clone());
583                        Some(p.into_header())
584                    }
585                    None => None,
586                };
587
588                let extensions_header = match accepted_extension {
589                    Some(ext) => {
590                        tracing::trace!("inject accepted ws extension in cfg: {ext:?}");
591                        req.extensions_mut().insert(ext.clone());
592                        Some(ext.into_header())
593                    }
594                    None => None,
595                };
596
597                match req.version() {
598                    version @ (Version::HTTP_10 | Version::HTTP_11) => {
599                        let accept_header = request_data.accept_header.ok_or_else(|| {
600                            tracing::debug!("WebSocketAcceptor: missing accept header (no key?)");
601                            StatusCode::BAD_REQUEST.into_response()
602                        })?;
603
604                        let mut response = (
605                            StatusCode::SWITCHING_PROTOCOLS,
606                            response::Headers((
607                                accept_header,
608                                headers::Upgrade::websocket(),
609                                headers::Connection::upgrade(),
610                            )),
611                        )
612                            .into_response();
613                        *response.version_mut() = version;
614                        if let Some(protocols) = protocols_header {
615                            response.headers_mut().typed_insert(protocols);
616                        }
617                        if let Some(extensions) = extensions_header {
618                            response.headers_mut().typed_insert(extensions);
619                        }
620                        Ok((response, req))
621                    }
622                    Version::HTTP_2 => {
623                        let mut response = StatusCode::OK.into_response();
624                        *response.version_mut() = Version::HTTP_2;
625                        if let Some(protocols) = protocols_header {
626                            response.headers_mut().typed_insert(protocols);
627                        }
628                        if let Some(extensions) = extensions_header {
629                            response.headers_mut().typed_insert(extensions);
630                        }
631                        Ok((response, req))
632                    }
633                    version => {
634                        tracing::debug!(
635                            http.version = ?version,
636                            "WebSocketAcceptor: http client request has unexpected http version"
637                        );
638                        Err(StatusCode::BAD_REQUEST.into_response())
639                    }
640                }
641            }
642            Err(err) => {
643                let response =
644                    if matches!(err, RequestValidateError::InvalidSecWebSocketVersionHeader) {
645                        (
646                            Headers::single(headers::SecWebSocketVersion::V13),
647                            StatusCode::BAD_REQUEST,
648                        )
649                            .into_response()
650                    } else {
651                        StatusCode::BAD_REQUEST.into_response()
652                    };
653                tracing::debug!("WebSocketAcceptor: http client request failed to validate: {err}");
654                Err(response)
655            }
656        }
657    }
658}
659
660/// Shortcut that can be used for endpoint WS services.
661///
662/// Created via [`WebSocketAcceptor::into_service`]
663/// or `WebSocketAcceptor::into_echo_service`].
664#[derive(Debug, Clone)]
665pub struct WebSocketAcceptorService<S> {
666    acceptor: WebSocketAcceptor,
667    config: Option<WebSocketConfig>,
668    service: S,
669}
670
671impl<S> WebSocketAcceptorService<S> {
672    rama_utils::macros::generate_set_and_with! {
673        /// Set the [`WebSocketConfig`], overwriting the previous config if already set.
674        pub fn config(mut self, cfg: Option<WebSocketConfig>) -> Self {
675            self.config = cfg;
676            self
677        }
678    }
679}
680
681#[derive(Debug)]
682/// Server WebSocket, used as input-output stream.
683///
684/// Utility type created via [`WebSocketAcceptorService`].
685///
686/// [`AcceptedSubProtocol`] can be found in the [`Context`], if any.
687pub struct ServerWebSocket {
688    socket: AsyncWebSocket,
689    request: request::Parts,
690}
691
692impl Deref for ServerWebSocket {
693    type Target = AsyncWebSocket;
694
695    fn deref(&self) -> &Self::Target {
696        &self.socket
697    }
698}
699
700impl DerefMut for ServerWebSocket {
701    fn deref_mut(&mut self) -> &mut Self::Target {
702        &mut self.socket
703    }
704}
705
706impl ServerWebSocket {
707    /// View the original request data, from which this server web socket was created.
708    pub fn request(&self) -> &request::Parts {
709        &self.request
710    }
711
712    /// Consume `self` as an [`AsyncWebSocket].
713    pub fn into_inner(self) -> AsyncWebSocket {
714        self.socket
715    }
716
717    /// Consume `self` into its parts.
718    pub fn into_parts(self) -> (AsyncWebSocket, request::Parts) {
719        (self.socket, self.request)
720    }
721}
722
723impl<S, Body> Service<Request<Body>> for WebSocketAcceptorService<S>
724where
725    S: Clone + Service<ServerWebSocket, Output = ()>,
726    Body: Send + 'static,
727{
728    type Output = Response;
729    type Error = S::Error;
730
731    async fn serve(&self, req: Request<Body>) -> Result<Self::Output, Self::Error> {
732        match self.acceptor.serve(req).await {
733            Ok((resp, req)) => {
734                #[cfg(not(feature = "compression"))]
735                if let Some(Extension::PerMessageDeflate(_)) = req.extensions().get() {
736                    tracing::error!(
737                        "per-message-deflate is used but compression feature is disabled. Enable it if you wish to use this extension."
738                    );
739                    return Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response());
740                }
741
742                let handler = self.service.clone();
743                let span = tracing::trace_root_span!(
744                    "ws::serve",
745                    otel.kind = "server",
746                    url.full = %req.uri(),
747                    url.path = %req.uri().path(),
748                    url.query = req.uri().query().unwrap_or_default(),
749                    url.scheme = %req.uri().scheme().map(|s| s.as_str()).unwrap_or_default(),
750                    network.protocol.name = "ws",
751                );
752
753                let exec = req
754                    .extensions()
755                    .get::<Executor>()
756                    .cloned()
757                    .unwrap_or_default();
758
759                exec.spawn_task(
760                    async move {
761                        match upgrade::handle_upgrade(&req).await {
762                            Ok(upgraded) => {
763                                #[cfg(feature = "compression")]
764                                let maybe_ws_config = {
765                                    let mut ws_cfg = None;
766
767                                    tracing::trace!("check if pmd settings have to be applied to WS cfg...");
768
769                                    if let Some(Extension::PerMessageDeflate(pmd_cfg)) = req.extensions().get() {
770                                        tracing::trace!(
771                                            "apply accepted per-message-deflate cfg into WS server config: {pmd_cfg:?}"
772                                        );
773                                        ws_cfg = Some(WebSocketConfig {
774                                            per_message_deflate: Some(pmd_cfg.into()),
775                                            ..Default::default()
776                                        });
777                                    }
778
779                                    ws_cfg
780                                };
781
782                                #[cfg(not(feature = "compression"))]
783                                let maybe_ws_config = None;
784
785                                let socket =
786                                    AsyncWebSocket::from_raw_socket(upgraded, Role::Server, maybe_ws_config)
787                                        .await;
788
789                                let (parts, _) = req.into_parts();
790
791                                let server_socket = ServerWebSocket {
792                                    socket,
793                                    request: parts,
794                                };
795
796                                let _ = handler.serve( server_socket).await;
797                            }
798                            Err(e) => {
799                                tracing::error!("ws upgrade error: {e:?}");
800                            }
801                        }
802                    }
803                    .instrument(span),
804                );
805                Ok(resp)
806            }
807            Err(resp) => Ok(resp),
808        }
809    }
810}
811
// Plain `&str` form of the default echo sub-protocol, used as the fallback
// protocol name when no sub-protocol was accepted for a socket.
const ECHO_SERVICE_SUB_PROTOCOL_DEFAULT_STR: &str = "echo";

/// Default sub-protocol used by [`WebSocketEchoService`]: messages are echoed back unchanged.
/// This behaviour is also used when no (known) sub-protocol was negotiated.
pub const ECHO_SERVICE_SUB_PROTOCOL_DEFAULT: NonEmptyStr =
    non_empty_str!(ECHO_SERVICE_SUB_PROTOCOL_DEFAULT_STR);
/// Sub-protocol for [`WebSocketEchoService`] that uppercases all characters of echoed text messages.
pub const ECHO_SERVICE_SUB_PROTOCOL_UPPER: NonEmptyStr = non_empty_str!("echo-upper");
/// Sub-protocol for [`WebSocketEchoService`] that lowercases all characters of echoed text messages.
pub const ECHO_SERVICE_SUB_PROTOCOL_LOWER: NonEmptyStr = non_empty_str!("echo-lower");
821
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
/// A service which echoes all incoming text and binary messages back to the peer.
///
/// The accepted sub-protocol selects a transformation for text messages:
/// `echo-upper` uppercases them, `echo-lower` lowercases them, and anything else
/// (including no sub-protocol at all) echoes them unchanged.
pub struct WebSocketEchoService;

impl WebSocketEchoService {
    /// Create a new [`WebSocketEchoService`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
834
835impl Service<AsyncWebSocket> for WebSocketEchoService {
836    type Output = ();
837    type Error = OpaqueError;
838
839    async fn serve(&self, mut socket: AsyncWebSocket) -> Result<Self::Output, Self::Error> {
840        let protocol = socket
841            .extensions()
842            .get::<headers::sec_websocket_protocol::AcceptedWebSocketProtocol>()
843            .map(|p| p.0.as_ref())
844            .unwrap_or(ECHO_SERVICE_SUB_PROTOCOL_DEFAULT_STR);
845
846        let transformer = if protocol.eq_ignore_ascii_case(&ECHO_SERVICE_SUB_PROTOCOL_LOWER) {
847            |msg: Message| match msg {
848                Message::Text(original) => Some(original.to_lowercase().into()),
849                msg @ Message::Binary(_) => Some(msg),
850                Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
851            }
852        } else if protocol.eq_ignore_ascii_case(&ECHO_SERVICE_SUB_PROTOCOL_UPPER) {
853            |msg: Message| match msg {
854                Message::Text(original) => Some(original.to_uppercase().into()),
855                msg @ Message::Binary(_) => Some(msg),
856                Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
857            }
858        } else {
859            |msg: Message| match msg {
860                msg @ (Message::Text(_) | Message::Binary(_)) => Some(msg),
861                Message::Ping(_) | Message::Pong(_) | Message::Close(_) | Message::Frame(_) => None,
862            }
863        };
864
865        loop {
866            let msg = socket.recv_message().await.context("recv next msg")?;
867            if let Some(msg2) = transformer(msg) {
868                socket.send_message(msg2).await.context("echo msg back")?;
869            }
870        }
871    }
872}
873
874impl Service<ServerWebSocket> for WebSocketEchoService {
875    type Output = ();
876    type Error = OpaqueError;
877
878    async fn serve(&self, socket: ServerWebSocket) -> Result<Self::Output, Self::Error> {
879        let socket = socket.into_inner();
880        self.serve(socket).await
881    }
882}
883
884impl Service<upgrade::Upgraded> for WebSocketEchoService {
885    type Output = ();
886    type Error = OpaqueError;
887
888    async fn serve(&self, io: upgrade::Upgraded) -> Result<Self::Output, Self::Error> {
889        #[cfg(not(feature = "compression"))]
890        let maybe_ws_config = {
891            if let Some(Extension::PerMessageDeflate(_)) = io.extensions().get() {
892                return Err(OpaqueError::from_display(
893                    "per-message-deflate is used but compression feature is disabled. Enable it if you wish to use this extension.",
894                ));
895            }
896            None
897        };
898
899        #[cfg(feature = "compression")]
900        let maybe_ws_config = {
901            let mut ws_cfg = None;
902
903            tracing::debug!("check if pmd settings have to be applied to WS cfg...");
904
905            if let Some(Extension::PerMessageDeflate(pmd_cfg)) = io.extensions().get() {
906                tracing::debug!(
907                    "apply accepted per-message-deflate cfg into WS server config: {pmd_cfg:?}"
908                );
909                ws_cfg = Some(WebSocketConfig {
910                    per_message_deflate: Some(pmd_cfg.into()),
911                    ..Default::default()
912                });
913            }
914
915            ws_cfg
916        };
917
918        let socket = AsyncWebSocket::from_raw_socket(io, Role::Server, maybe_ws_config).await;
919        self.serve(socket).await
920    }
921}
922
// Unit tests for `WebSocketMatcher` and `WebSocketAcceptor` handshake validation.
#[cfg(test)]
mod tests {
    use headers::sec_websocket_protocol::AcceptedWebSocketProtocol;
    use rama_http::Body;
    use rama_utils::str::non_empty_str;

    use super::*;

    // Build a `Request` with an empty body from a compact description:
    //
    //     request! { "METHOD" "VERSION" "/uri" "Header": "value" ... w/ [extension, ...] }
    //
    // Only the methods "GET"/"POST"/"CONNECT" and the versions "HTTP/1.1"/"HTTP/2"
    // are supported; anything else hits `unreachable!()`. The optional `w/ [...]`
    // list inserts request extensions. The first arm forwards to the second one
    // with an empty extension list.
    macro_rules! request {
        (
            $method:literal $version:literal $uri:literal
            $(
                $header_name:literal: $header_value:literal
            )*
        ) => {
            request!(
                $method $version $uri
                $(
                    $header_name: $header_value
                )*
                w/ []
            )
        };
        (
            $method:literal $version:literal $uri:literal
            $(
                $header_name:literal: $header_value:literal
            )*
            w/ [$($extension:expr),* $(,)?]
        ) => {
            {
                let req = Request::builder()
                    .uri($uri)
                    .version(match $version {
                        "HTTP/1.1" => Version::HTTP_11,
                        "HTTP/2" => Version::HTTP_2,
                        _ => unreachable!(),
                    })
                    .method(match $method {
                        "GET" => Method::GET,
                        "POST" => Method::POST,
                        "CONNECT" => Method::CONNECT,
                        _ => unreachable!(),
                    });

                $(
                    let req = req.header($header_name, $header_value);
                )*

                $(
                    let req = req.extension($extension);
                )*

                req.body(Body::empty()).unwrap()
            }
        };
    }

    // Assert that `matcher` does NOT match `request`.
    fn assert_websocket_no_match(request: &Request, matcher: &WebSocketMatcher) {
        assert!(
            !matcher.matches(None, request),
            "!({matcher:?}).matches({request:?})"
        );
    }

    // Assert that `matcher` matches `request`.
    fn assert_websocket_match(request: &Request, matcher: &WebSocketMatcher) {
        assert!(
            matcher.matches(None, request),
            "({matcher:?}).matches({request:?})"
        );
    }

    // The default matcher on http/1.1 requires both the `Connection: upgrade`
    // and the `Upgrade: websocket` headers on a GET request.
    #[test]
    fn test_websocket_match_default_http_11() {
        let matcher = WebSocketMatcher::default();

        assert_websocket_no_match(
            &request! {
                "GET" "HTTP/1.1" "/"
            },
            &matcher,
        );
        assert_websocket_no_match(
            &request! {
                "GET" "HTTP/1.1" "/"
                "Upgrade": "websocket"
            },
            &matcher,
        );
        assert_websocket_no_match(
            &request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
            },
            &matcher,
        );
        assert_websocket_match(
            &request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
            },
            &matcher,
        );
    }

    // On h2 the default matcher requires a CONNECT request carrying the
    // `websocket` `Protocol` extension; http/1.1-style upgrade headers on an h2
    // GET, or a GET with the extension, do not match.
    #[test]
    fn test_websocket_match_default_http_2() {
        let matcher = WebSocketMatcher::default();

        assert_websocket_no_match(
            &request! {
                "GET" "HTTP/2" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "foobar"
            },
            &matcher,
        );
        assert_websocket_match(
            &request! {
                "CONNECT" "HTTP/2" "/"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &matcher,
        );
        assert_websocket_no_match(
            &request! {
                "GET" "HTTP/2" "/"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &matcher,
        );
    }

    // Run the acceptor and assert it accepts the request: 101 for HTTP/1.x and
    // 200 for h2. If a protocol is expected, it must be the first protocol in the
    // response's `Sec-WebSocket-Protocol` header and must be stored in the
    // returned request's extensions; otherwise neither may be present.
    async fn assert_websocket_acceptor_ok(
        request: Request,
        acceptor: &WebSocketAcceptor,
        expected_accepted_protocol: Option<AcceptedWebSocketProtocol>,
    ) {
        let (resp, req) = acceptor.serve(request).await.unwrap();
        match req.version() {
            Version::HTTP_10 | Version::HTTP_11 => {
                assert_eq!(StatusCode::SWITCHING_PROTOCOLS, resp.status())
            }
            Version::HTTP_2 => assert_eq!(StatusCode::OK, resp.status()),
            _ => unreachable!(),
        }
        let accepted_protocol = resp
            .headers()
            .typed_get::<headers::SecWebSocketProtocol>()
            .map(|p| p.accept_first_protocol());
        if let Some(expected_accepted_protocol) = expected_accepted_protocol {
            assert_eq!(
                accepted_protocol.as_ref(),
                Some(&expected_accepted_protocol),
                "request = {req:?}"
            );
            assert_eq!(
                req.extensions().get::<AcceptedWebSocketProtocol>(),
                Some(&expected_accepted_protocol),
                "request = {req:?}"
            );
        } else {
            assert!(accepted_protocol.is_none());
            assert!(
                req.extensions()
                    .get::<AcceptedWebSocketProtocol>()
                    .is_none()
            );
        }
    }

    // Run the acceptor and assert it rejects the request with 400 Bad Request.
    async fn assert_websocket_acceptor_bad_request(request: Request, acceptor: &WebSocketAcceptor) {
        let resp = acceptor.serve(request).await.unwrap_err();
        assert_eq!(StatusCode::BAD_REQUEST, resp.status());
    }

    // Default acceptor on h2:
    // - an h2 GET (with or without the `Protocol` extension) is rejected;
    // - a CONNECT without `Sec-WebSocket-Version` is rejected;
    // - a CONNECT with version 13 and no sub-protocol is accepted;
    // - a request that also offers a sub-protocol (and a key) is rejected.
    #[tokio::test]
    async fn test_websocket_acceptor_default_http_2() {
        let acceptor = WebSocketAcceptor::default();

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/2" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "foobar"
            },
            &acceptor,
        )
        .await;
        assert_websocket_acceptor_bad_request(
            request! {
                "CONNECT" "HTTP/2" "/"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
        )
        .await;
        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/2" "/"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_ok(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
            None,
        )
        .await;

        assert_websocket_acceptor_bad_request(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
                "Sec-WebSocket-Protocol": "client"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
        )
        .await;
    }

    // Default acceptor on http/1.1. Each of the following is rejected:
    // - an invalid key;
    // - a missing version;
    // - an unsupported version;
    // - a wrong or missing `Upgrade` header;
    // - a missing or non-upgrade `Connection` header.
    // A complete request (using the sample nonce from RFC 6455 as key) is accepted.
    #[tokio::test]
    async fn test_websocket_acceptor_default_http_11() {
        let acceptor = WebSocketAcceptor::default();

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "foobar"
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "14"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "foo"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "keep-alive"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
        )
        .await;

        assert_websocket_acceptor_ok(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
            None,
        )
        .await;
    }

    // With flexible protocols enabled, any offered sub-protocol is accepted, and
    // the first offered one wins. Once an allow-list is also configured, offering
    // no protocol is still fine, but offered protocols must be on the list.
    #[tokio::test]
    async fn test_websocket_accept_flex_protocols() {
        let acceptor = WebSocketAcceptor::default().with_protocols_flex(true);

        // no protocols

        assert_websocket_acceptor_ok(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
            None,
        )
        .await;
        assert_websocket_acceptor_ok(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
            None,
        )
        .await;

        // with protocols

        assert_websocket_acceptor_ok(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
                "Sec-WebSocket-Protocol": "foo"
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
        )
        .await;
        assert_websocket_acceptor_ok(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Protocol": "foo"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
        )
        .await;

        // with multiple protocols

        assert_websocket_acceptor_ok(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
                "Sec-WebSocket-Protocol": "foo, bar"
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
        )
        .await;
        assert_websocket_acceptor_ok(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Protocol": "foo,baz, foo"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
        )
        .await;

        // without protocols, even though we have allow list, fine due to it being optional,
        // but we still only accept allowed protocols if defined

        let acceptor =
            acceptor.with_protocols(headers::SecWebSocketProtocol::new(non_empty_str!("foo")));

        assert_websocket_acceptor_ok(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
            None,
        )
        .await;

        // "fo" is only a prefix of the allowed "foo", so it must not be accepted.
        assert_websocket_acceptor_bad_request(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Protocol": "baz,fo"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
        )
        .await;
    }

    // With an allow-list and flex mode off, a sub-protocol is required. The first
    // offered protocol that is on the allow-list is accepted; offering only
    // protocols that are not on the list is rejected.
    #[tokio::test]
    async fn test_websocket_accept_required_protocols() {
        let acceptor = WebSocketAcceptor::default().with_protocols(headers::SecWebSocketProtocol(
            non_empty_smallvec![
                non_empty_str!("foo"),
                non_empty_str!("a"),
                non_empty_str!("b")
            ],
        ));

        // no protocols, required so all bad

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
            },
            &acceptor,
        )
        .await;
        assert_websocket_acceptor_bad_request(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
        )
        .await;

        // with allowed protocol

        assert_websocket_acceptor_ok(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
                "Sec-WebSocket-Protocol": "foo"
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("foo"))),
        )
        .await;
        assert_websocket_acceptor_ok(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Protocol": "b"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("b"))),
        )
        .await;

        // with multiple protocols (including at least one allowed one)

        assert_websocket_acceptor_ok(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
                "Sec-WebSocket-Protocol": "test, b"
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("b"))),
        )
        .await;
        assert_websocket_acceptor_ok(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Protocol": "a,test, c"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
            Some(AcceptedWebSocketProtocol(non_empty_str!("a"))),
        )
        .await;

        // only with non-allowed protocol(s)

        assert_websocket_acceptor_bad_request(
            request! {
                "GET" "HTTP/1.1" "/"
                "Connection": "upgrade"
                "Upgrade": "websocket"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ=="
                "Sec-WebSocket-Protocol": "test, c"
            },
            &acceptor,
        )
        .await;
        assert_websocket_acceptor_bad_request(
            request! {
                "CONNECT" "HTTP/2" "/"
                "Sec-WebSocket-Version": "13"
                "Sec-WebSocket-Protocol": "test"
                w/ [
                    Protocol::from_static("websocket"),
                ]
            },
            &acceptor,
        )
        .await;
    }
}