1use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use axum_core::body::Body;
97use futures_util::{
98 sink::{Sink, SinkExt},
99 stream::{Stream, StreamExt},
100};
101use http::{
102 header::{self, HeaderMap, HeaderName, HeaderValue},
103 request::Parts,
104 Method, StatusCode, Version,
105};
106use hyper_util::rt::TokioIo;
107use sha1::{Digest, Sha1};
108use std::{
109 borrow::Cow,
110 future::Future,
111 pin::Pin,
112 task::{ready, Context, Poll},
113};
114use tokio_tungstenite::{
115 tungstenite::{
116 self as ts,
117 protocol::{self, WebSocketConfig},
118 },
119 WebSocketStream,
120};
121
/// Extractor for establishing WebSocket connections.
///
/// Produced by the [`FromRequestParts`] impl below once the handshake request has been
/// validated. Configure it with the builder methods, then call
/// [`WebSocketUpgrade::on_upgrade`] to obtain the handshake response.
///
/// `F` is the callback run when the connection upgrade fails; it defaults to
/// [`DefaultOnFailedUpgrade`], which ignores the error.
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
    /// Protocol configuration handed to `tokio-tungstenite` when the socket is created.
    config: WebSocketConfig,
    /// The subprotocol chosen by [`WebSocketUpgrade::protocols`], if any.
    protocol: Option<HeaderValue>,
    /// The client's `Sec-WebSocket-Key`; `Some` for HTTP/1.1-and-older requests, `None`
    /// for newer versions (extended CONNECT).
    sec_websocket_key: Option<HeaderValue>,
    /// Upgrade handle taken out of the request extensions.
    on_upgrade: hyper::upgrade::OnUpgrade,
    /// Callback invoked if awaiting `on_upgrade` fails.
    on_failed_upgrade: F,
    /// The raw `Sec-WebSocket-Protocol` header sent by the client, if any.
    sec_websocket_protocol: Option<HeaderValue>,
}
142
143impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
144 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 f.debug_struct("WebSocketUpgrade")
146 .field("config", &self.config)
147 .field("protocol", &self.protocol)
148 .field("sec_websocket_key", &self.sec_websocket_key)
149 .field("sec_websocket_protocol", &self.sec_websocket_protocol)
150 .finish_non_exhaustive()
151 }
152}
153
154impl<F> WebSocketUpgrade<F> {
155 pub fn read_buffer_size(mut self, size: usize) -> Self {
157 self.config.read_buffer_size = size;
158 self
159 }
160
161 pub fn write_buffer_size(mut self, size: usize) -> Self {
171 self.config.write_buffer_size = size;
172 self
173 }
174
175 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
187 self.config.max_write_buffer_size = max;
188 self
189 }
190
191 pub fn max_message_size(mut self, max: usize) -> Self {
193 self.config.max_message_size = Some(max);
194 self
195 }
196
197 pub fn max_frame_size(mut self, max: usize) -> Self {
199 self.config.max_frame_size = Some(max);
200 self
201 }
202
203 pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
205 self.config.accept_unmasked_frames = accept;
206 self
207 }
208
209 pub fn protocols<I>(mut self, protocols: I) -> Self
240 where
241 I: IntoIterator,
242 I::Item: Into<Cow<'static, str>>,
243 {
244 if let Some(req_protocols) = self
245 .sec_websocket_protocol
246 .as_ref()
247 .and_then(|p| p.to_str().ok())
248 {
249 self.protocol = protocols
250 .into_iter()
251 .map(Into::into)
254 .find(|protocol| {
255 req_protocols
256 .split(',')
257 .any(|req_protocol| req_protocol.trim() == protocol)
258 })
259 .map(|protocol| match protocol {
260 Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
261 Cow::Borrowed(s) => HeaderValue::from_static(s),
262 });
263 }
264
265 self
266 }
267
268 pub fn selected_protocol(&self) -> Option<&HeaderValue> {
274 self.protocol.as_ref()
275 }
276
277 pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
302 where
303 C: OnFailedUpgrade,
304 {
305 WebSocketUpgrade {
306 config: self.config,
307 protocol: self.protocol,
308 sec_websocket_key: self.sec_websocket_key,
309 on_upgrade: self.on_upgrade,
310 on_failed_upgrade: callback,
311 sec_websocket_protocol: self.sec_websocket_protocol,
312 }
313 }
314
315 #[must_use = "to set up the WebSocket connection, this response must be returned"]
318 pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
319 where
320 C: FnOnce(WebSocket) -> Fut + Send + 'static,
321 Fut: Future<Output = ()> + Send + 'static,
322 F: OnFailedUpgrade,
323 {
324 let on_upgrade = self.on_upgrade;
325 let config = self.config;
326 let on_failed_upgrade = self.on_failed_upgrade;
327
328 let protocol = self.protocol.clone();
329
330 tokio::spawn(async move {
331 let upgraded = match on_upgrade.await {
332 Ok(upgraded) => upgraded,
333 Err(err) => {
334 on_failed_upgrade.call(Error::new(err));
335 return;
336 }
337 };
338 let upgraded = TokioIo::new(upgraded);
339
340 let socket =
341 WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
342 .await;
343 let socket = WebSocket {
344 inner: socket,
345 protocol,
346 };
347 callback(socket).await;
348 });
349
350 let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
351 #[allow(clippy::declare_interior_mutable_const)]
354 const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
355 #[allow(clippy::declare_interior_mutable_const)]
356 const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
357
358 Response::builder()
359 .status(StatusCode::SWITCHING_PROTOCOLS)
360 .header(header::CONNECTION, UPGRADE)
361 .header(header::UPGRADE, WEBSOCKET)
362 .header(
363 header::SEC_WEBSOCKET_ACCEPT,
364 sign(sec_websocket_key.as_bytes()),
365 )
366 .body(Body::empty())
367 .unwrap()
368 } else {
369 Response::new(Body::empty())
373 };
374
375 if let Some(protocol) = self.protocol {
376 response
377 .headers_mut()
378 .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
379 }
380
381 response
382 }
383}
384
385pub trait OnFailedUpgrade: Send + 'static {
389 fn call(self, error: Error);
391}
392
393impl<F> OnFailedUpgrade for F
394where
395 F: FnOnce(Error) + Send + 'static,
396{
397 fn call(self, error: Error) {
398 self(error)
399 }
400}
401
402#[non_exhaustive]
406#[derive(Debug)]
407pub struct DefaultOnFailedUpgrade;
408
409impl OnFailedUpgrade for DefaultOnFailedUpgrade {
410 #[inline]
411 fn call(self, _error: Error) {}
412}
413
414impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
415where
416 S: Send + Sync,
417{
418 type Rejection = WebSocketUpgradeRejection;
419
420 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
421 let sec_websocket_key = if parts.version <= Version::HTTP_11 {
422 if parts.method != Method::GET {
423 return Err(MethodNotGet.into());
424 }
425
426 if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
427 return Err(InvalidConnectionHeader.into());
428 }
429
430 if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
431 return Err(InvalidUpgradeHeader.into());
432 }
433
434 Some(
435 parts
436 .headers
437 .get(header::SEC_WEBSOCKET_KEY)
438 .ok_or(WebSocketKeyHeaderMissing)?
439 .clone(),
440 )
441 } else {
442 if parts.method != Method::CONNECT {
443 return Err(MethodNotConnect.into());
444 }
445
446 #[cfg(feature = "http2")]
449 if parts
450 .extensions
451 .get::<hyper::ext::Protocol>()
452 .map_or(true, |p| p.as_str() != "websocket")
453 {
454 return Err(InvalidProtocolPseudoheader.into());
455 }
456
457 None
458 };
459
460 if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
461 return Err(InvalidWebSocketVersionHeader.into());
462 }
463
464 let on_upgrade = parts
465 .extensions
466 .remove::<hyper::upgrade::OnUpgrade>()
467 .ok_or(ConnectionNotUpgradable)?;
468
469 let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
470
471 Ok(Self {
472 config: Default::default(),
473 protocol: None,
474 sec_websocket_key,
475 on_upgrade,
476 sec_websocket_protocol,
477 on_failed_upgrade: DefaultOnFailedUpgrade,
478 })
479 }
480}
481
482fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
483 if let Some(header) = headers.get(&key) {
484 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
485 } else {
486 false
487 }
488}
489
490fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
491 let header = if let Some(header) = headers.get(&key) {
492 header
493 } else {
494 return false;
495 };
496
497 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
498 header.to_ascii_lowercase().contains(value)
499 } else {
500 false
501 }
502}
503
/// An established WebSocket connection, handed to the callback given to
/// [`WebSocketUpgrade::on_upgrade`].
///
/// Implements [`Stream`] for receiving and [`Sink`] for sending [`Message`]s.
#[derive(Debug)]
pub struct WebSocket {
    /// The `tokio-tungstenite` stream over the upgraded hyper connection.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    /// The subprotocol selected during the handshake, if any.
    protocol: Option<HeaderValue>,
}
512
513impl WebSocket {
514 pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
518 self.next().await
519 }
520
521 pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
523 self.inner
524 .send(msg.into_tungstenite())
525 .await
526 .map_err(Error::new)
527 }
528
529 pub fn protocol(&self) -> Option<&HeaderValue> {
531 self.protocol.as_ref()
532 }
533}
534
535impl Stream for WebSocket {
536 type Item = Result<Message, Error>;
537
538 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
539 loop {
540 match ready!(self.inner.poll_next_unpin(cx)) {
541 Some(Ok(msg)) => {
542 if let Some(msg) = Message::from_tungstenite(msg) {
543 return Poll::Ready(Some(Ok(msg)));
544 }
545 }
546 Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
547 None => return Poll::Ready(None),
548 }
549 }
550 }
551}
552
553impl Sink<Message> for WebSocket {
554 type Error = Error;
555
556 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
557 Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
558 }
559
560 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
561 Pin::new(&mut self.inner)
562 .start_send(item.into_tungstenite())
563 .map_err(Error::new)
564 }
565
566 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
567 Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
568 }
569
570 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
571 Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
572 }
573}
574
575#[derive(Debug, Clone, PartialEq, Eq, Default)]
579pub struct Utf8Bytes(ts::Utf8Bytes);
580
581impl Utf8Bytes {
582 #[inline]
584 pub const fn from_static(str: &'static str) -> Self {
585 Self(ts::Utf8Bytes::from_static(str))
586 }
587
588 #[inline]
590 pub fn as_str(&self) -> &str {
591 self.0.as_str()
592 }
593
594 fn into_tungstenite(self) -> ts::Utf8Bytes {
595 self.0
596 }
597}
598
599impl std::ops::Deref for Utf8Bytes {
600 type Target = str;
601
602 #[inline]
615 fn deref(&self) -> &Self::Target {
616 self.as_str()
617 }
618}
619
620impl std::fmt::Display for Utf8Bytes {
621 #[inline]
622 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
623 f.write_str(self.as_str())
624 }
625}
626
627impl TryFrom<Bytes> for Utf8Bytes {
628 type Error = std::str::Utf8Error;
629
630 #[inline]
631 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
632 Ok(Self(bytes.try_into()?))
633 }
634}
635
636impl TryFrom<Vec<u8>> for Utf8Bytes {
637 type Error = std::str::Utf8Error;
638
639 #[inline]
640 fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
641 Ok(Self(v.try_into()?))
642 }
643}
644
645impl From<String> for Utf8Bytes {
646 #[inline]
647 fn from(s: String) -> Self {
648 Self(s.into())
649 }
650}
651
652impl From<&str> for Utf8Bytes {
653 #[inline]
654 fn from(s: &str) -> Self {
655 Self(s.into())
656 }
657}
658
659impl From<&String> for Utf8Bytes {
660 #[inline]
661 fn from(s: &String) -> Self {
662 Self(s.into())
663 }
664}
665
666impl From<Utf8Bytes> for Bytes {
667 #[inline]
668 fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
669 bytes.into()
670 }
671}
672
673impl<T> PartialEq<T> for Utf8Bytes
674where
675 for<'a> &'a str: PartialEq<T>,
676{
677 #[inline]
685 fn eq(&self, other: &T) -> bool {
686 self.as_str() == *other
687 }
688}
689
/// Status code sent with a close frame; well-known values live in [`close_code`].
pub type CloseCode = u16;

/// Payload of a close message: a status code plus a textual reason.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The close status code.
    pub code: CloseCode,
    /// Human-readable reason for closing.
    pub reason: Utf8Bytes,
}
701
/// A WebSocket message, as received from or sent through a [`WebSocket`].
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
    /// A text message carrying valid UTF-8.
    Text(Utf8Bytes),
    /// A binary message.
    Binary(Bytes),
    /// A ping control message with its payload.
    Ping(Bytes),
    /// A pong control message with its payload.
    Pong(Bytes),
    /// A close message, optionally carrying a code and reason.
    Close(Option<CloseFrame>),
}
767
768impl Message {
769 fn into_tungstenite(self) -> ts::Message {
770 match self {
771 Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
772 Self::Binary(binary) => ts::Message::Binary(binary),
773 Self::Ping(ping) => ts::Message::Ping(ping),
774 Self::Pong(pong) => ts::Message::Pong(pong),
775 Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
776 code: ts::protocol::frame::coding::CloseCode::from(close.code),
777 reason: close.reason.into_tungstenite(),
778 })),
779 Self::Close(None) => ts::Message::Close(None),
780 }
781 }
782
783 fn from_tungstenite(message: ts::Message) -> Option<Self> {
784 match message {
785 ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
786 ts::Message::Binary(binary) => Some(Self::Binary(binary)),
787 ts::Message::Ping(ping) => Some(Self::Ping(ping)),
788 ts::Message::Pong(pong) => Some(Self::Pong(pong)),
789 ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
790 code: close.code.into(),
791 reason: Utf8Bytes(close.reason),
792 }))),
793 ts::Message::Close(None) => Some(Self::Close(None)),
794 ts::Message::Frame(_) => None,
797 }
798 }
799
800 pub fn into_data(self) -> Bytes {
802 match self {
803 Self::Text(string) => Bytes::from(string),
804 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
805 Self::Close(None) => Bytes::new(),
806 Self::Close(Some(frame)) => Bytes::from(frame.reason),
807 }
808 }
809
810 pub fn into_text(self) -> Result<Utf8Bytes, Error> {
812 match self {
813 Self::Text(string) => Ok(string),
814 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
815 Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
816 }
817 Self::Close(None) => Ok(Utf8Bytes::default()),
818 Self::Close(Some(frame)) => Ok(frame.reason),
819 }
820 }
821
822 pub fn to_text(&self) -> Result<&str, Error> {
825 match *self {
826 Self::Text(ref string) => Ok(string.as_str()),
827 Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
828 Ok(std::str::from_utf8(data).map_err(Error::new)?)
829 }
830 Self::Close(None) => Ok(""),
831 Self::Close(Some(ref frame)) => Ok(&frame.reason),
832 }
833 }
834
835 pub fn text<S>(string: S) -> Message
837 where
838 S: Into<Utf8Bytes>,
839 {
840 Message::Text(string.into())
841 }
842
843 pub fn binary<B>(bin: B) -> Message
845 where
846 B: Into<Bytes>,
847 {
848 Message::Binary(bin.into())
849 }
850}
851
852impl From<String> for Message {
853 fn from(string: String) -> Self {
854 Message::Text(string.into())
855 }
856}
857
858impl<'s> From<&'s str> for Message {
859 fn from(string: &'s str) -> Self {
860 Message::Text(string.into())
861 }
862}
863
864impl<'b> From<&'b [u8]> for Message {
865 fn from(data: &'b [u8]) -> Self {
866 Message::Binary(Bytes::copy_from_slice(data))
867 }
868}
869
870impl From<Bytes> for Message {
871 fn from(data: Bytes) -> Self {
872 Message::Binary(data)
873 }
874}
875
876impl From<Vec<u8>> for Message {
877 fn from(data: Vec<u8>) -> Self {
878 Message::Binary(data.into())
879 }
880}
881
882impl From<Message> for Vec<u8> {
883 fn from(msg: Message) -> Self {
884 msg.into_data().to_vec()
885 }
886}
887
888fn sign(key: &[u8]) -> HeaderValue {
889 use base64::engine::Engine as _;
890
891 let mut sha1 = Sha1::default();
892 sha1.update(key);
893 sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
894 let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
895 HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
896}
897
/// Rejections produced by the [`WebSocketUpgrade`](super::WebSocketUpgrade) extractor.
pub mod rejection {
    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when an HTTP/1.1 (or older) handshake request does not use `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `CONNECT`"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when a handshake over an HTTP version newer than 1.1 does not use
        /// `CONNECT`.
        pub struct MethodNotConnect;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when an HTTP/1.1 (or older) request's `Connection` header does not
        /// include `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when an HTTP/1.1 (or older) request's `Upgrade` header is not
        /// `websocket`.
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`:protocol` pseudo-header did not include 'websocket'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned (with the `http2` feature) when an extended CONNECT request's
        /// `:protocol` pseudo-header is missing or is not `websocket`.
        pub struct InvalidProtocolPseudoheader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the `Sec-WebSocket-Version` header is missing or is not `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when an HTTP/1.1 (or older) request has no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Returned when the request carries no `hyper::upgrade::OnUpgrade` extension, so
        /// the underlying connection cannot be taken over. The `rejects_http_1_0_requests`
        /// test hits this with a request built directly rather than received via hyper.
        pub struct ConnectionNotUpgradable;
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            MethodNotConnect,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidProtocolPseudoheader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
984
/// Well-known WebSocket close status codes, for use as [`CloseCode`](super::CloseCode)
/// values (RFC 6455 §7.4.1 and the IANA WebSocket Close Code Number Registry).
pub mod close_code {
    /// Normal closure: the purpose of the connection has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// The endpoint is going away (e.g. server shutdown or browser navigating away).
    pub const AWAY: u16 = 1001;

    /// The endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// The endpoint received a type of data it cannot accept.
    pub const UNSUPPORTED: u16 = 1003;

    /// No status code was present. Reserved: must not be sent in a close frame.
    pub const STATUS: u16 = 1005;

    /// The connection closed abnormally without a close frame. Reserved: must not be
    /// sent in a close frame.
    pub const ABNORMAL: u16 = 1006;

    /// A message contained data inconsistent with its type (e.g. invalid UTF-8 in text).
    pub const INVALID: u16 = 1007;

    /// A message violated the endpoint's policy (generic code when no other fits).
    pub const POLICY: u16 = 1008;

    /// A message was too big to process.
    pub const SIZE: u16 = 1009;

    /// Client only: the server did not negotiate an extension the client requires.
    pub const EXTENSION: u16 = 1010;

    /// The server hit an unexpected condition that prevented it from fulfilling the request.
    pub const ERROR: u16 = 1011;

    /// The service is restarting (IANA-registered).
    pub const RESTART: u16 = 1012;

    /// The service is temporarily overloaded; try again later (IANA-registered).
    pub const AGAIN: u16 = 1013;
}
1053
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::any, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use http_body_util::BodyExt as _;
    use hyper_util::rt::TokioExecutor;
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    /// An HTTP/1.0 request passes the header checks, but it is built directly (not received
    /// through hyper), so it carries no upgrade state and extraction must fail with
    /// `ConnectionNotUpgradable`.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        // The handler itself returns `()`, so the response is a plain 200.
        assert_eq!(res.status(), StatusCode::OK);
    }

    /// Compile-only check: a handler using the default failure callback is a valid route.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    /// Compile-only check: a handler with a custom failure callback is a valid route.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    /// End-to-end HTTP/1.1 handshake and echo using a `tokio-tungstenite` client.
    #[crate::test]
    async fn integration_test() {
        let addr = spawn_service(echo_app());
        let uri = format!("ws://{addr}/echo").try_into().unwrap();
        let req = tungstenite::client::ClientRequestBuilder::new(uri)
            .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
        let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
        test_echo_app(socket, response.headers()).await;
    }

    /// End-to-end HTTP/2 extended CONNECT handshake and echo using a raw hyper client.
    #[crate::test]
    #[cfg(feature = "http2")]
    async fn http2() {
        let addr = spawn_service(echo_app());
        let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (mut send_request, conn) =
            hyper::client::conn::http2::Builder::new(TokioExecutor::new())
                .handshake(io)
                .await
                .unwrap();

        // NOTE(review): presumably yields so background tasks can process the server's
        // SETTINGS frame before the extended-CONNECT assertion below; confirm.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        assert!(conn.is_extended_connect_protocol_enabled());
        tokio::spawn(async {
            conn.await.unwrap();
        });

        let req = Request::builder()
            .method(Method::CONNECT)
            .extension(hyper::ext::Protocol::from_static("websocket"))
            .uri("/echo")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
            .header("Host", "server.example.com")
            .body(Body::empty())
            .unwrap();

        let mut response = send_request.send_request(req).await.unwrap();
        let status = response.status();
        if status != 200 {
            let body = response.into_body().collect().await.unwrap().to_bytes();
            let body = std::str::from_utf8(&body).unwrap();
            panic!("response status was {status}: {body}");
        }
        let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
        let upgraded = TokioIo::new(upgraded);
        let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
        test_echo_app(socket, response.headers()).await;
    }

    /// A router with a single `/echo` WebSocket route that echoes text, binary and close
    /// messages back to the client.
    fn echo_app() -> Router {
        async fn handle_socket(mut socket: WebSocket) {
            assert_eq!(socket.protocol().unwrap(), "echo");
            while let Some(Ok(msg)) = socket.recv().await {
                match msg {
                    Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
                        if socket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        // Ignored here; `test_echo_app` still expects a Pong for its Ping,
                        // so that reply comes from the underlying WebSocket implementation.
                    }
                }
            }
        }

        Router::new().route(
            "/echo",
            any(|ws: WebSocketUpgrade| {
                // The client offers "echo3, echo"; "echo" is the first server-preferred
                // entry that the client also offered.
                let ws = ws.protocols(["echo2", "echo"]);
                assert_eq!(ws.selected_protocol().unwrap(), "echo");
                ready(ws.on_upgrade(handle_socket))
            }),
        )
    }

    /// Subprotocols the test clients offer in `Sec-WebSocket-Protocol`.
    const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";

    /// Verifies the negotiated subprotocol, a text echo, and a ping/pong round trip.
    async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
        mut socket: WebSocketStream<S>,
        headers: &http::HeaderMap,
    ) {
        assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

        let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
        socket.send(input.clone()).await.unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        socket
            .send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
            .await
            .unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(
            output,
            tungstenite::Message::Pong(Bytes::from_static(b"ping"))
        );
    }
}