1use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use axum_core::body::Body;
97use futures_util::{
98 sink::{Sink, SinkExt},
99 stream::{FusedStream, Stream, StreamExt},
100};
101use http::{
102 header::{self, HeaderMap, HeaderName, HeaderValue},
103 request::Parts,
104 Method, StatusCode, Version,
105};
106use hyper_util::rt::TokioIo;
107use sha1::{Digest, Sha1};
108use std::{
109 borrow::Cow,
110 future::Future,
111 pin::Pin,
112 task::{ready, Context, Poll},
113};
114use tokio_tungstenite::{
115 tungstenite::{
116 self as ts,
117 protocol::{self, WebSocketConfig},
118 },
119 WebSocketStream,
120};
121
122#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
132#[must_use]
133pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
134 config: WebSocketConfig,
135 protocol: Option<HeaderValue>,
137 sec_websocket_key: Option<HeaderValue>,
139 on_upgrade: hyper::upgrade::OnUpgrade,
140 on_failed_upgrade: F,
141 sec_websocket_protocol: Option<HeaderValue>,
142}
143
144impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("WebSocketUpgrade")
147 .field("config", &self.config)
148 .field("protocol", &self.protocol)
149 .field("sec_websocket_key", &self.sec_websocket_key)
150 .field("sec_websocket_protocol", &self.sec_websocket_protocol)
151 .finish_non_exhaustive()
152 }
153}
154
155impl<F> WebSocketUpgrade<F> {
156 pub fn read_buffer_size(mut self, size: usize) -> Self {
158 self.config.read_buffer_size = size;
159 self
160 }
161
162 pub fn write_buffer_size(mut self, size: usize) -> Self {
172 self.config.write_buffer_size = size;
173 self
174 }
175
176 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
188 self.config.max_write_buffer_size = max;
189 self
190 }
191
192 pub fn max_message_size(mut self, max: usize) -> Self {
194 self.config.max_message_size = Some(max);
195 self
196 }
197
198 pub fn max_frame_size(mut self, max: usize) -> Self {
200 self.config.max_frame_size = Some(max);
201 self
202 }
203
204 pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
206 self.config.accept_unmasked_frames = accept;
207 self
208 }
209
210 pub fn protocols<I>(mut self, protocols: I) -> Self
241 where
242 I: IntoIterator,
243 I::Item: Into<Cow<'static, str>>,
244 {
245 if let Some(req_protocols) = self
246 .sec_websocket_protocol
247 .as_ref()
248 .and_then(|p| p.to_str().ok())
249 {
250 self.protocol = protocols
251 .into_iter()
252 .map(Into::into)
255 .find(|protocol| {
256 req_protocols
257 .split(',')
258 .any(|req_protocol| req_protocol.trim() == protocol)
259 })
260 .map(|protocol| match protocol {
261 Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
262 Cow::Borrowed(s) => HeaderValue::from_static(s),
263 });
264 }
265
266 self
267 }
268
269 pub fn selected_protocol(&self) -> Option<&HeaderValue> {
275 self.protocol.as_ref()
276 }
277
278 pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
303 where
304 C: OnFailedUpgrade,
305 {
306 WebSocketUpgrade {
307 config: self.config,
308 protocol: self.protocol,
309 sec_websocket_key: self.sec_websocket_key,
310 on_upgrade: self.on_upgrade,
311 on_failed_upgrade: callback,
312 sec_websocket_protocol: self.sec_websocket_protocol,
313 }
314 }
315
316 #[must_use = "to set up the WebSocket connection, this response must be returned"]
319 pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
320 where
321 C: FnOnce(WebSocket) -> Fut + Send + 'static,
322 Fut: Future<Output = ()> + Send + 'static,
323 F: OnFailedUpgrade,
324 {
325 let on_upgrade = self.on_upgrade;
326 let config = self.config;
327 let on_failed_upgrade = self.on_failed_upgrade;
328
329 let protocol = self.protocol.clone();
330
331 tokio::spawn(async move {
332 let upgraded = match on_upgrade.await {
333 Ok(upgraded) => upgraded,
334 Err(err) => {
335 on_failed_upgrade.call(Error::new(err));
336 return;
337 }
338 };
339 let upgraded = TokioIo::new(upgraded);
340
341 let socket =
342 WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
343 .await;
344 let socket = WebSocket {
345 inner: socket,
346 protocol,
347 };
348 callback(socket).await;
349 });
350
351 let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
352 #[allow(clippy::declare_interior_mutable_const)]
355 const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
356 #[allow(clippy::declare_interior_mutable_const)]
357 const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
358
359 Response::builder()
360 .status(StatusCode::SWITCHING_PROTOCOLS)
361 .header(header::CONNECTION, UPGRADE)
362 .header(header::UPGRADE, WEBSOCKET)
363 .header(
364 header::SEC_WEBSOCKET_ACCEPT,
365 sign(sec_websocket_key.as_bytes()),
366 )
367 .body(Body::empty())
368 .unwrap()
369 } else {
370 Response::new(Body::empty())
374 };
375
376 if let Some(protocol) = self.protocol {
377 response
378 .headers_mut()
379 .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
380 }
381
382 response
383 }
384}
385
386pub trait OnFailedUpgrade: Send + 'static {
390 fn call(self, error: Error);
392}
393
394impl<F> OnFailedUpgrade for F
395where
396 F: FnOnce(Error) + Send + 'static,
397{
398 fn call(self, error: Error) {
399 self(error)
400 }
401}
402
/// The failure handler a [`WebSocketUpgrade`] uses unless
/// [`WebSocketUpgrade::on_failed_upgrade`] installs another one.
///
/// Its [`OnFailedUpgrade::call`] implementation discards the error without any other action.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpgrade;
409
410impl OnFailedUpgrade for DefaultOnFailedUpgrade {
411 #[inline]
412 fn call(self, _error: Error) {}
413}
414
415impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
416where
417 S: Send + Sync,
418{
419 type Rejection = WebSocketUpgradeRejection;
420
421 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
422 let sec_websocket_key = if parts.version <= Version::HTTP_11 {
423 if parts.method != Method::GET {
424 return Err(MethodNotGet.into());
425 }
426
427 if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
428 return Err(InvalidConnectionHeader.into());
429 }
430
431 if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
432 return Err(InvalidUpgradeHeader.into());
433 }
434
435 Some(
436 parts
437 .headers
438 .get(header::SEC_WEBSOCKET_KEY)
439 .ok_or(WebSocketKeyHeaderMissing)?
440 .clone(),
441 )
442 } else {
443 if parts.method != Method::CONNECT {
444 return Err(MethodNotConnect.into());
445 }
446
447 #[cfg(feature = "http2")]
450 if parts
451 .extensions
452 .get::<hyper::ext::Protocol>()
453 .map_or(true, |p| p.as_str() != "websocket")
454 {
455 return Err(InvalidProtocolPseudoheader.into());
456 }
457
458 None
459 };
460
461 if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
462 return Err(InvalidWebSocketVersionHeader.into());
463 }
464
465 let on_upgrade = parts
466 .extensions
467 .remove::<hyper::upgrade::OnUpgrade>()
468 .ok_or(ConnectionNotUpgradable)?;
469
470 let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
471
472 Ok(Self {
473 config: Default::default(),
474 protocol: None,
475 sec_websocket_key,
476 on_upgrade,
477 sec_websocket_protocol,
478 on_failed_upgrade: DefaultOnFailedUpgrade,
479 })
480 }
481}
482
483fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
484 if let Some(header) = headers.get(&key) {
485 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
486 } else {
487 false
488 }
489}
490
491fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
492 let header = if let Some(header) = headers.get(&key) {
493 header
494 } else {
495 return false;
496 };
497
498 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
499 header.to_ascii_lowercase().contains(value)
500 } else {
501 false
502 }
503}
504
/// An established WebSocket connection, handed to the callback given to
/// [`WebSocketUpgrade::on_upgrade`].
///
/// Messages are received through [`WebSocket::recv`] or the [`Stream`] impl, and sent
/// through [`WebSocket::send`] or the [`Sink`] impl.
#[derive(Debug)]
pub struct WebSocket {
    /// tungstenite stream running over the upgraded hyper connection.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    /// Subprotocol selected during the handshake, if any.
    protocol: Option<HeaderValue>,
}
513
514impl WebSocket {
515 pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
519 self.next().await
520 }
521
522 pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
524 self.inner
525 .send(msg.into_tungstenite())
526 .await
527 .map_err(Error::new)
528 }
529
530 pub fn protocol(&self) -> Option<&HeaderValue> {
532 self.protocol.as_ref()
533 }
534}
535
536impl FusedStream for WebSocket {
537 fn is_terminated(&self) -> bool {
539 self.inner.is_terminated()
540 }
541}
542
543impl Stream for WebSocket {
544 type Item = Result<Message, Error>;
545
546 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
547 loop {
548 match ready!(self.inner.poll_next_unpin(cx)) {
549 Some(Ok(msg)) => {
550 if let Some(msg) = Message::from_tungstenite(msg) {
551 return Poll::Ready(Some(Ok(msg)));
552 }
553 }
554 Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
555 None => return Poll::Ready(None),
556 }
557 }
558 }
559}
560
561impl Sink<Message> for WebSocket {
562 type Error = Error;
563
564 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
565 Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
566 }
567
568 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
569 Pin::new(&mut self.inner)
570 .start_send(item.into_tungstenite())
571 .map_err(Error::new)
572 }
573
574 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
575 Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
576 }
577
578 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
579 Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
580 }
581}
582
/// Text stored as bytes that are known to be valid UTF-8.
///
/// Wraps tungstenite's `Utf8Bytes`. Conversions from raw bytes (`TryFrom<Bytes>`,
/// `TryFrom<Vec<u8>>`) fail with `std::str::Utf8Error` on invalid input.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Utf8Bytes(ts::Utf8Bytes);
588
589impl Utf8Bytes {
590 #[inline]
592 #[must_use]
593 pub const fn from_static(str: &'static str) -> Self {
594 Self(ts::Utf8Bytes::from_static(str))
595 }
596
597 #[inline]
599 pub fn as_str(&self) -> &str {
600 self.0.as_str()
601 }
602
603 fn into_tungstenite(self) -> ts::Utf8Bytes {
604 self.0
605 }
606}
607
608impl std::ops::Deref for Utf8Bytes {
609 type Target = str;
610
611 #[inline]
624 fn deref(&self) -> &Self::Target {
625 self.as_str()
626 }
627}
628
629impl std::fmt::Display for Utf8Bytes {
630 #[inline]
631 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
632 f.write_str(self.as_str())
633 }
634}
635
636impl TryFrom<Bytes> for Utf8Bytes {
637 type Error = std::str::Utf8Error;
638
639 #[inline]
640 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
641 Ok(Self(bytes.try_into()?))
642 }
643}
644
645impl TryFrom<Vec<u8>> for Utf8Bytes {
646 type Error = std::str::Utf8Error;
647
648 #[inline]
649 fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
650 Ok(Self(v.try_into()?))
651 }
652}
653
654impl From<String> for Utf8Bytes {
655 #[inline]
656 fn from(s: String) -> Self {
657 Self(s.into())
658 }
659}
660
661impl From<&str> for Utf8Bytes {
662 #[inline]
663 fn from(s: &str) -> Self {
664 Self(s.into())
665 }
666}
667
668impl From<&String> for Utf8Bytes {
669 #[inline]
670 fn from(s: &String) -> Self {
671 Self(s.into())
672 }
673}
674
675impl From<Utf8Bytes> for Bytes {
676 #[inline]
677 fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
678 bytes.into()
679 }
680}
681
682impl<T> PartialEq<T> for Utf8Bytes
683where
684 for<'a> &'a str: PartialEq<T>,
685{
686 #[inline]
694 fn eq(&self, other: &T) -> bool {
695 self.as_str() == *other
696 }
697}
698
/// Numeric status code carried in a WebSocket close frame.
///
/// Named values are provided in the [`close_code`] module.
pub type CloseCode = u16;

/// Payload of a close message: why the endpoint is closing the connection.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The close status code.
    pub code: CloseCode,
    /// Human-readable reason text.
    pub reason: Utf8Bytes,
}
710
/// A message sent or received over a [`WebSocket`].
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
    /// A text message; the payload is guaranteed to be valid UTF-8.
    Text(Utf8Bytes),
    /// A binary message.
    Binary(Bytes),
    /// A ping with the given payload.
    Ping(Bytes),
    /// A pong with the given payload.
    Pong(Bytes),
    /// A close message, optionally carrying a status code and reason.
    Close(Option<CloseFrame>),
}
776
777impl Message {
778 fn into_tungstenite(self) -> ts::Message {
779 match self {
780 Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
781 Self::Binary(binary) => ts::Message::Binary(binary),
782 Self::Ping(ping) => ts::Message::Ping(ping),
783 Self::Pong(pong) => ts::Message::Pong(pong),
784 Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
785 code: ts::protocol::frame::coding::CloseCode::from(close.code),
786 reason: close.reason.into_tungstenite(),
787 })),
788 Self::Close(None) => ts::Message::Close(None),
789 }
790 }
791
792 fn from_tungstenite(message: ts::Message) -> Option<Self> {
793 match message {
794 ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
795 ts::Message::Binary(binary) => Some(Self::Binary(binary)),
796 ts::Message::Ping(ping) => Some(Self::Ping(ping)),
797 ts::Message::Pong(pong) => Some(Self::Pong(pong)),
798 ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
799 code: close.code.into(),
800 reason: Utf8Bytes(close.reason),
801 }))),
802 ts::Message::Close(None) => Some(Self::Close(None)),
803 ts::Message::Frame(_) => None,
806 }
807 }
808
809 pub fn into_data(self) -> Bytes {
811 match self {
812 Self::Text(string) => Bytes::from(string),
813 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
814 Self::Close(None) => Bytes::new(),
815 Self::Close(Some(frame)) => Bytes::from(frame.reason),
816 }
817 }
818
819 pub fn into_text(self) -> Result<Utf8Bytes, Error> {
821 match self {
822 Self::Text(string) => Ok(string),
823 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
824 Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
825 }
826 Self::Close(None) => Ok(Utf8Bytes::default()),
827 Self::Close(Some(frame)) => Ok(frame.reason),
828 }
829 }
830
831 pub fn to_text(&self) -> Result<&str, Error> {
834 match *self {
835 Self::Text(ref string) => Ok(string.as_str()),
836 Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
837 Ok(std::str::from_utf8(data).map_err(Error::new)?)
838 }
839 Self::Close(None) => Ok(""),
840 Self::Close(Some(ref frame)) => Ok(&frame.reason),
841 }
842 }
843
844 pub fn text<S>(string: S) -> Message
846 where
847 S: Into<Utf8Bytes>,
848 {
849 Message::Text(string.into())
850 }
851
852 pub fn binary<B>(bin: B) -> Message
854 where
855 B: Into<Bytes>,
856 {
857 Message::Binary(bin.into())
858 }
859}
860
861impl From<String> for Message {
862 fn from(string: String) -> Self {
863 Message::Text(string.into())
864 }
865}
866
867impl<'s> From<&'s str> for Message {
868 fn from(string: &'s str) -> Self {
869 Message::Text(string.into())
870 }
871}
872
873impl<'b> From<&'b [u8]> for Message {
874 fn from(data: &'b [u8]) -> Self {
875 Message::Binary(Bytes::copy_from_slice(data))
876 }
877}
878
879impl From<Bytes> for Message {
880 fn from(data: Bytes) -> Self {
881 Message::Binary(data)
882 }
883}
884
885impl From<Vec<u8>> for Message {
886 fn from(data: Vec<u8>) -> Self {
887 Message::Binary(data.into())
888 }
889}
890
891impl From<Message> for Vec<u8> {
892 fn from(msg: Message) -> Self {
893 msg.into_data().to_vec()
894 }
895}
896
897fn sign(key: &[u8]) -> HeaderValue {
898 use base64::engine::Engine as _;
899
900 let mut sha1 = Sha1::default();
901 sha1.update(key);
902 sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
903 let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
904 HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
905}
906
pub mod rejection {
    //! Rejections produced by the [`WebSocketUpgrade`](super::WebSocketUpgrade) extractor.

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): an HTTP/1.x
        /// upgrade request did not use `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `CONNECT`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): a request on a
        /// protocol newer than HTTP/1.1 did not use `CONNECT`.
        pub struct MethodNotConnect;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): the
        /// `Connection` header did not contain `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): the `Upgrade`
        /// header was not `websocket`.
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`:protocol` pseudo-header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): an extended
        /// `CONNECT` request's `:protocol` pseudo-header was missing or not `websocket`.
        pub struct InvalidProtocolPseudoheader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): the
        /// `Sec-WebSocket-Version` header was not `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): an HTTP/1.x
        /// upgrade request had no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade): the request
        /// carried no `hyper::upgrade::OnUpgrade` extension, so the connection cannot be
        /// upgraded (e.g. HTTP/1.0 requests, as exercised by this file's tests).
        pub struct ConnectionNotUpgradable;
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            MethodNotConnect,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidProtocolPseudoheader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
993
pub mod close_code {
    //! Named values for [`CloseCode`](super::CloseCode), as defined in RFC 6455 §7.4.1 and
    //! the IANA "WebSocket Close Code Number Registry".

    /// Normal closure: the purpose for which the connection was established has been
    /// fulfilled.
    pub const NORMAL: u16 = 1000;

    /// The endpoint is "going away", e.g. a server shutting down or a browser navigating
    /// away from the page.
    pub const AWAY: u16 = 1001;

    /// The connection is being terminated due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// The endpoint received a type of data it cannot accept (e.g. binary when it only
    /// understands text).
    pub const UNSUPPORTED: u16 = 1003;

    /// Reserved: indicates that no status code was present. RFC 6455 forbids sending it in a
    /// Close frame.
    pub const STATUS: u16 = 1005;

    /// Reserved: indicates the connection closed abnormally without a Close frame. RFC 6455
    /// forbids sending it in a Close frame.
    pub const ABNORMAL: u16 = 1006;

    /// The endpoint received data inconsistent with the message type (e.g. non-UTF-8 data
    /// in a text message).
    pub const INVALID: u16 = 1007;

    /// The endpoint received a message that violates its policy; a generic code when no
    /// more specific one fits.
    pub const POLICY: u16 = 1008;

    /// The endpoint received a message too big to process.
    pub const SIZE: u16 = 1009;

    /// Sent by a client when the server did not negotiate one or more extensions the
    /// client expected.
    pub const EXTENSION: u16 = 1010;

    /// The server encountered an unexpected condition that prevented it from fulfilling
    /// the request.
    pub const ERROR: u16 = 1011;

    /// The service is restarting (IANA-registered).
    pub const RESTART: u16 = 1012;

    /// The service is temporarily unavailable; try again later (IANA-registered).
    pub const AGAIN: u16 = 1013;
}
1062
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::any, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use http_body_util::BodyExt as _;
    use hyper_util::rt::TokioExecutor;
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    /// An HTTP/1.0 request with otherwise valid handshake headers is rejected with
    /// `ConnectionNotUpgradable`. It is built with `Request::builder` and sent via
    /// `oneshot`, so it has no `hyper::upgrade::OnUpgrade` extension and extraction fails
    /// at that check. The handler itself returns `()`, so the response status is 200.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // Compile-only check: a handler using the default failure callback can be routed.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    // Compile-only check: a closure can be installed as the failure callback.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    /// End-to-end HTTP/1.1 handshake against a real server, offering the subprotocols in
    /// `TEST_ECHO_APP_REQ_SUBPROTO`.
    #[crate::test]
    async fn integration_test() {
        let addr = spawn_service(echo_app());
        let uri = format!("ws://{addr}/echo").try_into().unwrap();
        let req = tungstenite::client::ClientRequestBuilder::new(uri)
            .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
        let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
        test_echo_app(socket, response.headers()).await;
    }

    /// End-to-end WebSocket over HTTP/2 extended CONNECT.
    #[crate::test]
    #[cfg(feature = "http2")]
    async fn http2() {
        let addr = spawn_service(echo_app());
        let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (mut send_request, conn) =
            hyper::client::conn::http2::Builder::new(TokioExecutor::new())
                .handshake(io)
                .await
                .unwrap();

        // Give the runtime a few turns so the server's SETTINGS (which enable extended
        // CONNECT) can be processed before the assertion below.
        // NOTE(review): this relies on hyper driving the HTTP/2 connection from a task on the
        // executor, since `conn` itself is only spawned afterwards; confirm if this flakes.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        assert!(conn.is_extended_connect_protocol_enabled());
        tokio::spawn(async {
            conn.await.unwrap();
        });

        let req = Request::builder()
            .method(Method::CONNECT)
            .extension(hyper::ext::Protocol::from_static("websocket"))
            .uri("/echo")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
            .header("Host", "server.example.com")
            .body(Body::empty())
            .unwrap();

        let mut response = send_request.send_request(req).await.unwrap();
        let status = response.status();
        if status != 200 {
            // Surface the rejection body to make failures diagnosable.
            let body = response.into_body().collect().await.unwrap().to_bytes();
            let body = std::str::from_utf8(&body).unwrap();
            panic!("response status was {status}: {body}");
        }
        let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
        let upgraded = TokioIo::new(upgraded);
        let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
        test_echo_app(socket, response.headers()).await;
    }

    /// Router with a single `/echo` WebSocket route that selects the `echo` subprotocol and
    /// echoes text, binary and close messages back to the client.
    fn echo_app() -> Router {
        async fn handle_socket(mut socket: WebSocket) {
            assert_eq!(socket.protocol().unwrap(), "echo");
            while let Some(Ok(msg)) = socket.recv().await {
                match msg {
                    Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
                        if socket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        // Deliberately not echoed: `test_echo_app` still expects a `Pong`
                        // for its `Ping`, so the reply comes from below the handler.
                    }
                }
            }
        }

        Router::new().route(
            "/echo",
            any(|ws: WebSocketUpgrade| {
                // The client offers "echo3, echo"; "echo2" is not offered, so "echo" wins.
                let ws = ws.protocols(["echo2", "echo"]);
                assert_eq!(ws.selected_protocol().unwrap(), "echo");
                ready(ws.on_upgrade(handle_socket))
            }),
        )
    }

    // Subprotocols the test clients offer in `Sec-WebSocket-Protocol`.
    const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";

    /// Client-side checks shared by the HTTP/1.1 and HTTP/2 tests.
    ///
    /// Verifies the negotiated subprotocol header, that a text message is echoed verbatim,
    /// and that a ping is answered with a pong carrying the same payload.
    async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
        mut socket: WebSocketStream<S>,
        headers: &http::HeaderMap,
    ) {
        assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

        let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
        socket.send(input.clone()).await.unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        socket
            .send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
            .await
            .unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(
            output,
            tungstenite::Message::Pong(Bytes::from_static(b"ping"))
        );
    }
}