1use self::rejection::*;
94use super::FromRequestParts;
95use crate::{body::Bytes, response::Response, Error};
96use axum_core::body::Body;
97use futures_util::{
98 sink::{Sink, SinkExt},
99 stream::{FusedStream, Stream, StreamExt},
100};
101use http::{
102 header::{self, HeaderMap, HeaderName, HeaderValue},
103 request::Parts,
104 Method, StatusCode, Version,
105};
106use hyper_util::rt::TokioIo;
107use sha1::{Digest, Sha1};
108use std::{
109 borrow::Cow,
110 collections::BTreeSet,
111 future::Future,
112 pin::Pin,
113 str,
114 task::{ready, Context, Poll},
115};
116use tokio_tungstenite::{
117 tungstenite::{
118 self as ts,
119 protocol::{self, WebSocketConfig},
120 },
121 WebSocketStream,
122};
123
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
/// Extractor for establishing WebSocket connections.
///
/// Extracting this type validates the WebSocket handshake headers of the request (see its
/// [`FromRequestParts`] impl). The connection is then established by returning the response
/// produced by [`WebSocketUpgrade::on_upgrade`].
///
/// `F` is the callback run when the upgrade fails; it defaults to [`DefaultOnFailedUpgrade`]
/// and can be replaced with [`WebSocketUpgrade::on_failed_upgrade`].
#[must_use]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
    /// tungstenite configuration applied to the upgraded connection.
    config: WebSocketConfig,
    /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
    protocol: Option<HeaderValue>,
    /// `None` if HTTP/2+ WebSockets are used (extended CONNECT carries no key).
    sec_websocket_key: Option<HeaderValue>,
    /// hyper's handle that resolves to the upgraded I/O once the response has been sent.
    on_upgrade: hyper::upgrade::OnUpgrade,
    /// Called with the error if `on_upgrade` fails.
    on_failed_upgrade: F,
    /// Subprotocols requested by the client via `Sec-WebSocket-Protocol`.
    sec_websocket_protocol: BTreeSet<HeaderValue>,
}
145
146impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("WebSocketUpgrade")
149 .field("config", &self.config)
150 .field("protocol", &self.protocol)
151 .field("sec_websocket_key", &self.sec_websocket_key)
152 .field("sec_websocket_protocol", &self.sec_websocket_protocol)
153 .finish_non_exhaustive()
154 }
155}
156
157impl<F> WebSocketUpgrade<F> {
158 pub fn read_buffer_size(mut self, size: usize) -> Self {
160 self.config.read_buffer_size = size;
161 self
162 }
163
164 pub fn write_buffer_size(mut self, size: usize) -> Self {
174 self.config.write_buffer_size = size;
175 self
176 }
177
178 pub fn max_write_buffer_size(mut self, max: usize) -> Self {
190 self.config.max_write_buffer_size = max;
191 self
192 }
193
194 pub fn max_message_size(mut self, max: usize) -> Self {
196 self.config.max_message_size = Some(max);
197 self
198 }
199
200 pub fn max_frame_size(mut self, max: usize) -> Self {
202 self.config.max_frame_size = Some(max);
203 self
204 }
205
206 pub fn accept_unmasked_frames(mut self, accept: bool) -> Self {
208 self.config.accept_unmasked_frames = accept;
209 self
210 }
211
212 pub fn protocols<I>(mut self, protocols: I) -> Self
243 where
244 I: IntoIterator,
245 I::Item: Into<Cow<'static, str>>,
246 {
247 self.protocol = protocols
248 .into_iter()
249 .map(Into::into)
250 .find(|proto| {
251 let Ok(proto) = HeaderValue::from_str(proto) else {
256 return false;
257 };
258 self.sec_websocket_protocol.contains(&proto)
259 })
260 .map(|protocol| match protocol {
261 Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
262 Cow::Borrowed(s) => HeaderValue::from_static(s),
263 });
264
265 self
266 }
267
268 pub fn requested_protocols(&self) -> impl Iterator<Item = &HeaderValue> {
280 self.sec_websocket_protocol.iter()
281 }
282
283 pub fn set_selected_protocol(&mut self, protocol: HeaderValue) {
294 self.protocol = Some(protocol);
295 }
296
297 pub fn selected_protocol(&self) -> Option<&HeaderValue> {
303 self.protocol.as_ref()
304 }
305
306 pub fn on_failed_upgrade<C>(self, callback: C) -> WebSocketUpgrade<C>
331 where
332 C: OnFailedUpgrade,
333 {
334 WebSocketUpgrade {
335 config: self.config,
336 protocol: self.protocol,
337 sec_websocket_key: self.sec_websocket_key,
338 on_upgrade: self.on_upgrade,
339 on_failed_upgrade: callback,
340 sec_websocket_protocol: self.sec_websocket_protocol,
341 }
342 }
343
344 #[must_use = "to set up the WebSocket connection, this response must be returned"]
347 pub fn on_upgrade<C, Fut>(self, callback: C) -> Response
348 where
349 C: FnOnce(WebSocket) -> Fut + Send + 'static,
350 Fut: Future<Output = ()> + Send + 'static,
351 F: OnFailedUpgrade,
352 {
353 let on_upgrade = self.on_upgrade;
354 let config = self.config;
355 let on_failed_upgrade = self.on_failed_upgrade;
356
357 let protocol = self.protocol.clone();
358
359 tokio::spawn(async move {
360 let upgraded = match on_upgrade.await {
361 Ok(upgraded) => upgraded,
362 Err(err) => {
363 on_failed_upgrade.call(Error::new(err));
364 return;
365 }
366 };
367 let upgraded = TokioIo::new(upgraded);
368
369 let socket =
370 WebSocketStream::from_raw_socket(upgraded, protocol::Role::Server, Some(config))
371 .await;
372 let socket = WebSocket {
373 inner: socket,
374 protocol,
375 };
376 callback(socket).await;
377 });
378
379 let mut response = if let Some(sec_websocket_key) = &self.sec_websocket_key {
380 #[allow(clippy::declare_interior_mutable_const)]
383 const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
384 #[allow(clippy::declare_interior_mutable_const)]
385 const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");
386
387 Response::builder()
388 .status(StatusCode::SWITCHING_PROTOCOLS)
389 .header(header::CONNECTION, UPGRADE)
390 .header(header::UPGRADE, WEBSOCKET)
391 .header(
392 header::SEC_WEBSOCKET_ACCEPT,
393 sign(sec_websocket_key.as_bytes()),
394 )
395 .body(Body::empty())
396 .unwrap()
397 } else {
398 Response::new(Body::empty())
402 };
403
404 if let Some(protocol) = self.protocol {
405 response
406 .headers_mut()
407 .insert(header::SEC_WEBSOCKET_PROTOCOL, protocol);
408 }
409
410 response
411 }
412}
413
/// What to do when a connection upgrade fails.
///
/// See [`WebSocketUpgrade::on_failed_upgrade`] for installing a custom handler. Any
/// `FnOnce(Error) + Send + 'static` closure implements this trait.
pub trait OnFailedUpgrade: Send + 'static {
    /// Called with the error that made the upgrade fail. Consumes the handler.
    fn call(self, error: Error);
}
421
422impl<F> OnFailedUpgrade for F
423where
424 F: FnOnce(Error) + Send + 'static,
425{
426 fn call(self, error: Error) {
427 self(error)
428 }
429}
430
/// The default [`OnFailedUpgrade`] used by [`WebSocketUpgrade`].
///
/// It ignores the error: nothing is logged or reported.
#[non_exhaustive]
#[derive(Debug)]
pub struct DefaultOnFailedUpgrade;

impl OnFailedUpgrade for DefaultOnFailedUpgrade {
    // Intentionally a no-op; the error is dropped.
    #[inline]
    fn call(self, _error: Error) {}
}
442
443impl<S> FromRequestParts<S> for WebSocketUpgrade<DefaultOnFailedUpgrade>
444where
445 S: Send + Sync,
446{
447 type Rejection = WebSocketUpgradeRejection;
448
449 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
450 let sec_websocket_key = if parts.version <= Version::HTTP_11 {
451 if parts.method != Method::GET {
452 return Err(MethodNotGet.into());
453 }
454
455 if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
456 return Err(InvalidConnectionHeader.into());
457 }
458
459 if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
460 return Err(InvalidUpgradeHeader.into());
461 }
462
463 Some(
464 parts
465 .headers
466 .get(header::SEC_WEBSOCKET_KEY)
467 .ok_or(WebSocketKeyHeaderMissing)?
468 .clone(),
469 )
470 } else {
471 if parts.method != Method::CONNECT {
472 return Err(MethodNotConnect.into());
473 }
474
475 #[cfg(feature = "http2")]
478 if parts
479 .extensions
480 .get::<hyper::ext::Protocol>()
481 .map_or(true, |p| p.as_str() != "websocket")
482 {
483 return Err(InvalidProtocolPseudoheader.into());
484 }
485
486 None
487 };
488
489 if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
490 return Err(InvalidWebSocketVersionHeader.into());
491 }
492
493 let on_upgrade = parts
494 .extensions
495 .remove::<hyper::upgrade::OnUpgrade>()
496 .ok_or(ConnectionNotUpgradable)?;
497
498 let sec_websocket_protocol = parts
499 .headers
500 .get_all(header::SEC_WEBSOCKET_PROTOCOL)
501 .iter()
502 .flat_map(|val| val.as_bytes().split(|&b| b == b','))
503 .map(|proto| {
504 HeaderValue::from_bytes(proto.trim_ascii())
505 .expect("substring of HeaderValue is valid HeaderValue")
506 })
507 .collect();
508
509 Ok(Self {
510 config: Default::default(),
511 protocol: None,
512 sec_websocket_key,
513 on_upgrade,
514 sec_websocket_protocol,
515 on_failed_upgrade: DefaultOnFailedUpgrade,
516 })
517 }
518}
519
520fn header_eq(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
521 if let Some(header) = headers.get(&key) {
522 header.as_bytes().eq_ignore_ascii_case(value.as_bytes())
523 } else {
524 false
525 }
526}
527
528fn header_contains(headers: &HeaderMap, key: HeaderName, value: &'static str) -> bool {
529 let header = if let Some(header) = headers.get(&key) {
530 header
531 } else {
532 return false;
533 };
534
535 if let Ok(header) = std::str::from_utf8(header.as_bytes()) {
536 header.to_ascii_lowercase().contains(value)
537 } else {
538 false
539 }
540}
541
/// An established WebSocket connection, produced by [`WebSocketUpgrade::on_upgrade`].
///
/// Messages can be received with [`WebSocket::recv`] or through the [`Stream`] impl. They can
/// be sent with [`WebSocket::send`] or through the [`Sink`] impl.
#[derive(Debug)]
pub struct WebSocket {
    /// The underlying tungstenite stream over the upgraded hyper connection.
    inner: WebSocketStream<TokioIo<hyper::upgrade::Upgraded>>,
    /// Subprotocol selected during the handshake, if any.
    protocol: Option<HeaderValue>,
}
550
551impl WebSocket {
552 pub async fn recv(&mut self) -> Option<Result<Message, Error>> {
556 self.next().await
557 }
558
559 pub async fn send(&mut self, msg: Message) -> Result<(), Error> {
561 self.inner
562 .send(msg.into_tungstenite())
563 .await
564 .map_err(Error::new)
565 }
566
567 pub fn protocol(&self) -> Option<&HeaderValue> {
569 self.protocol.as_ref()
570 }
571}
572
impl FusedStream for WebSocket {
    /// Returns `true` if the stream should no longer be polled.
    ///
    /// Delegates to the underlying tungstenite stream.
    fn is_terminated(&self) -> bool {
        self.inner.is_terminated()
    }
}
579
580impl Stream for WebSocket {
581 type Item = Result<Message, Error>;
582
583 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
584 loop {
585 match ready!(self.inner.poll_next_unpin(cx)) {
586 Some(Ok(msg)) => {
587 if let Some(msg) = Message::from_tungstenite(msg) {
588 return Poll::Ready(Some(Ok(msg)));
589 }
590 }
591 Some(Err(err)) => return Poll::Ready(Some(Err(Error::new(err)))),
592 None => return Poll::Ready(None),
593 }
594 }
595 }
596}
597
598impl Sink<Message> for WebSocket {
599 type Error = Error;
600
601 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
602 Pin::new(&mut self.inner).poll_ready(cx).map_err(Error::new)
603 }
604
605 fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
606 Pin::new(&mut self.inner)
607 .start_send(item.into_tungstenite())
608 .map_err(Error::new)
609 }
610
611 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
612 Pin::new(&mut self.inner).poll_flush(cx).map_err(Error::new)
613 }
614
615 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
616 Pin::new(&mut self.inner).poll_close(cx).map_err(Error::new)
617 }
618}
619
#[derive(Debug, Clone, PartialEq, Eq, Default)]
/// UTF-8 text stored as bytes, as carried by [`Message::Text`] and [`CloseFrame::reason`].
///
/// Every constructor either starts from a Rust string or validates the bytes
/// (`TryFrom<Bytes>` / `TryFrom<Vec<u8>>`), so the contents are valid UTF-8.
pub struct Utf8Bytes(ts::Utf8Bytes);
625
626impl Utf8Bytes {
627 #[inline]
629 #[must_use]
630 pub const fn from_static(str: &'static str) -> Self {
631 Self(ts::Utf8Bytes::from_static(str))
632 }
633
634 #[inline]
636 pub fn as_str(&self) -> &str {
637 self.0.as_str()
638 }
639
640 fn into_tungstenite(self) -> ts::Utf8Bytes {
641 self.0
642 }
643}
644
645impl std::ops::Deref for Utf8Bytes {
646 type Target = str;
647
648 #[inline]
661 fn deref(&self) -> &Self::Target {
662 self.as_str()
663 }
664}
665
666impl std::fmt::Display for Utf8Bytes {
667 #[inline]
668 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
669 f.write_str(self.as_str())
670 }
671}
672
673impl TryFrom<Bytes> for Utf8Bytes {
674 type Error = std::str::Utf8Error;
675
676 #[inline]
677 fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
678 Ok(Self(bytes.try_into()?))
679 }
680}
681
682impl TryFrom<Vec<u8>> for Utf8Bytes {
683 type Error = std::str::Utf8Error;
684
685 #[inline]
686 fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
687 Ok(Self(v.try_into()?))
688 }
689}
690
691impl From<String> for Utf8Bytes {
692 #[inline]
693 fn from(s: String) -> Self {
694 Self(s.into())
695 }
696}
697
698impl From<&str> for Utf8Bytes {
699 #[inline]
700 fn from(s: &str) -> Self {
701 Self(s.into())
702 }
703}
704
705impl From<&String> for Utf8Bytes {
706 #[inline]
707 fn from(s: &String) -> Self {
708 Self(s.into())
709 }
710}
711
712impl From<Utf8Bytes> for Bytes {
713 #[inline]
714 fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
715 bytes.into()
716 }
717}
718
719impl<T> PartialEq<T> for Utf8Bytes
720where
721 for<'a> &'a str: PartialEq<T>,
722{
723 #[inline]
731 fn eq(&self, other: &T) -> bool {
732 self.as_str() == *other
733 }
734}
735
/// Status code sent in a WebSocket close frame; see the [`close_code`] module for common values.
pub type CloseCode = u16;
738
/// Payload of a close message: why the connection is being closed.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame {
    /// The close status code.
    pub code: CloseCode,
    /// Human-readable reason text accompanying the code.
    pub reason: Utf8Bytes,
}
747
#[derive(Debug, Eq, PartialEq, Clone)]
/// A message sent or received over a [`WebSocket`].
pub enum Message {
    /// A text message (valid UTF-8).
    Text(Utf8Bytes),
    /// A binary message.
    Binary(Bytes),
    /// A ping with the given payload.
    ///
    /// RFC 6455 §5.5 limits control-frame payloads to 125 bytes.
    Ping(Bytes),
    /// A pong with the given payload.
    ///
    /// NOTE(review): the underlying tungstenite stream appears to answer incoming pings
    /// automatically, so sending pongs by hand is usually unnecessary. Confirm against the
    /// tungstenite version in use.
    Pong(Bytes),
    /// A close message, optionally carrying a status code and reason.
    ///
    /// `None` closes without a status code.
    Close(Option<CloseFrame>),
}
813
814impl Message {
815 fn into_tungstenite(self) -> ts::Message {
816 match self {
817 Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
818 Self::Binary(binary) => ts::Message::Binary(binary),
819 Self::Ping(ping) => ts::Message::Ping(ping),
820 Self::Pong(pong) => ts::Message::Pong(pong),
821 Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
822 code: ts::protocol::frame::coding::CloseCode::from(close.code),
823 reason: close.reason.into_tungstenite(),
824 })),
825 Self::Close(None) => ts::Message::Close(None),
826 }
827 }
828
829 fn from_tungstenite(message: ts::Message) -> Option<Self> {
830 match message {
831 ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
832 ts::Message::Binary(binary) => Some(Self::Binary(binary)),
833 ts::Message::Ping(ping) => Some(Self::Ping(ping)),
834 ts::Message::Pong(pong) => Some(Self::Pong(pong)),
835 ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
836 code: close.code.into(),
837 reason: Utf8Bytes(close.reason),
838 }))),
839 ts::Message::Close(None) => Some(Self::Close(None)),
840 ts::Message::Frame(_) => None,
843 }
844 }
845
846 pub fn into_data(self) -> Bytes {
848 match self {
849 Self::Text(string) => Bytes::from(string),
850 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
851 Self::Close(None) => Bytes::new(),
852 Self::Close(Some(frame)) => Bytes::from(frame.reason),
853 }
854 }
855
856 pub fn into_text(self) -> Result<Utf8Bytes, Error> {
858 match self {
859 Self::Text(string) => Ok(string),
860 Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
861 Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
862 }
863 Self::Close(None) => Ok(Utf8Bytes::default()),
864 Self::Close(Some(frame)) => Ok(frame.reason),
865 }
866 }
867
868 pub fn to_text(&self) -> Result<&str, Error> {
871 match *self {
872 Self::Text(ref string) => Ok(string.as_str()),
873 Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
874 Ok(std::str::from_utf8(data).map_err(Error::new)?)
875 }
876 Self::Close(None) => Ok(""),
877 Self::Close(Some(ref frame)) => Ok(&frame.reason),
878 }
879 }
880
881 pub fn text<S>(string: S) -> Message
883 where
884 S: Into<Utf8Bytes>,
885 {
886 Message::Text(string.into())
887 }
888
889 pub fn binary<B>(bin: B) -> Message
891 where
892 B: Into<Bytes>,
893 {
894 Message::Binary(bin.into())
895 }
896}
897
898impl From<String> for Message {
899 fn from(string: String) -> Self {
900 Message::Text(string.into())
901 }
902}
903
904impl<'s> From<&'s str> for Message {
905 fn from(string: &'s str) -> Self {
906 Message::Text(string.into())
907 }
908}
909
910impl<'b> From<&'b [u8]> for Message {
911 fn from(data: &'b [u8]) -> Self {
912 Message::Binary(Bytes::copy_from_slice(data))
913 }
914}
915
916impl From<Bytes> for Message {
917 fn from(data: Bytes) -> Self {
918 Message::Binary(data)
919 }
920}
921
922impl From<Vec<u8>> for Message {
923 fn from(data: Vec<u8>) -> Self {
924 Message::Binary(data.into())
925 }
926}
927
928impl From<Message> for Vec<u8> {
929 fn from(msg: Message) -> Self {
930 msg.into_data().to_vec()
931 }
932}
933
934fn sign(key: &[u8]) -> HeaderValue {
935 use base64::engine::Engine as _;
936
937 let mut sha1 = Sha1::default();
938 sha1.update(key);
939 sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
940 let b64 = Bytes::from(base64::engine::general_purpose::STANDARD.encode(sha1.finalize()));
941 HeaderValue::from_maybe_shared(b64).expect("base64 is a valid value")
942}
943
pub mod rejection {
    //! Rejection types returned by the [`WebSocketUpgrade`](super::WebSocketUpgrade) extractor.

    use axum_core::__composite_rejection as composite_rejection;
    use axum_core::__define_rejection as define_rejection;

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `GET`"]
        /// Returned when an HTTP/1.x WebSocket request does not use `GET`.
        pub struct MethodNotGet;
    }

    define_rejection! {
        #[status = METHOD_NOT_ALLOWED]
        #[body = "Request method must be `CONNECT`"]
        /// Returned when an HTTP/2+ WebSocket request does not use `CONNECT`.
        pub struct MethodNotConnect;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "Connection header did not include 'upgrade'"]
        /// Returned when an HTTP/1.x request's `Connection` header does not contain `upgrade`.
        pub struct InvalidConnectionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Upgrade` header did not include 'websocket'"]
        /// Returned when an HTTP/1.x request's `Upgrade` header is not `websocket`.
        pub struct InvalidUpgradeHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`:protocol` pseudo-header did not include 'websocket'"]
        /// Returned when an HTTP/2 request's `:protocol` pseudo-header is missing or not
        /// `websocket`. Only checked when the `http2` feature is enabled.
        pub struct InvalidProtocolPseudoheader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Version` header did not include '13'"]
        /// Returned when the `Sec-WebSocket-Version` header is not `13`.
        pub struct InvalidWebSocketVersionHeader;
    }

    define_rejection! {
        #[status = BAD_REQUEST]
        #[body = "`Sec-WebSocket-Key` header missing"]
        /// Returned when an HTTP/1.x request has no `Sec-WebSocket-Key` header.
        pub struct WebSocketKeyHeaderMissing;
    }

    define_rejection! {
        #[status = UPGRADE_REQUIRED]
        #[body = "WebSocket request couldn't be upgraded since no upgrade state was present"]
        /// Returned when the request extensions contain no `hyper::upgrade::OnUpgrade` handle.
        ///
        /// This happens e.g. for HTTP/1.0 requests. NOTE(review): presumably also for requests
        /// not served by a hyper connection that supports upgrades — confirm.
        pub struct ConnectionNotUpgradable;
    }

    composite_rejection! {
        /// Rejection used for [`WebSocketUpgrade`](super::WebSocketUpgrade).
        ///
        /// Contains one variant for each way the extractor can fail.
        pub enum WebSocketUpgradeRejection {
            MethodNotGet,
            MethodNotConnect,
            InvalidConnectionHeader,
            InvalidUpgradeHeader,
            InvalidProtocolPseudoheader,
            InvalidWebSocketVersionHeader,
            WebSocketKeyHeaderMissing,
            ConnectionNotUpgradable,
        }
    }
}
1030
pub mod close_code {
    //! Well-known [`CloseCode`](super::CloseCode) values.
    //!
    //! The codes 1000–1011 are defined in RFC 6455 §7.4.1; 1012 and 1013 are registered in the
    //! IANA "WebSocket Close Code Number Registry".

    /// Normal closure: the purpose of the connection has been fulfilled.
    pub const NORMAL: u16 = 1000;

    /// The endpoint is going away, e.g. a server shutting down or a browser leaving the page.
    pub const AWAY: u16 = 1001;

    /// The endpoint is terminating the connection due to a protocol error.
    pub const PROTOCOL: u16 = 1002;

    /// The endpoint received a type of data it cannot accept
    /// (e.g. binary data when it only understands text).
    pub const UNSUPPORTED: u16 = 1003;

    /// No status code was present. Reserved: must not be sent in a close frame.
    pub const STATUS: u16 = 1005;

    /// The connection closed abnormally, without a close frame. Reserved: must not be sent in
    /// a close frame.
    pub const ABNORMAL: u16 = 1006;

    /// A message contained data inconsistent with its type
    /// (e.g. non-UTF-8 data within a text message).
    pub const INVALID: u16 = 1007;

    /// A message violated the endpoint's policy. This is a generic code for when no more
    /// specific one applies.
    pub const POLICY: u16 = 1008;

    /// A message was too big to process.
    pub const SIZE: u16 = 1009;

    /// The client expected the server to negotiate one or more extensions, but the server
    /// did not.
    pub const EXTENSION: u16 = 1010;

    /// The server encountered an unexpected condition that prevented it from fulfilling the
    /// request.
    pub const ERROR: u16 = 1011;

    /// The server is restarting.
    pub const RESTART: u16 = 1012;

    /// The server is temporarily unable to handle the request (e.g. overloaded); try again
    /// later.
    pub const AGAIN: u16 = 1013;
}
1099
#[cfg(test)]
mod tests {
    use std::future::ready;

    use super::*;
    use crate::{routing::any, test_helpers::spawn_service, Router};
    use http::{Request, Version};
    use http_body_util::BodyExt as _;
    use hyper_util::rt::TokioExecutor;
    use tokio::io::{AsyncRead, AsyncWrite};
    use tokio::net::TcpStream;
    use tokio_tungstenite::tungstenite;
    use tower::ServiceExt;

    /// An HTTP/1.0 request is rejected with `ConnectionNotUpgradable`, even though its
    /// handshake headers are otherwise valid. The handler takes the `Result`, so the response
    /// itself is still `200 OK`.
    #[crate::test]
    async fn rejects_http_1_0_requests() {
        let svc = any(|ws: Result<WebSocketUpgrade, WebSocketUpgradeRejection>| {
            let rejection = ws.unwrap_err();
            assert!(matches!(
                rejection,
                WebSocketUpgradeRejection::ConnectionNotUpgradable(_)
            ));
            std::future::ready(())
        });

        let req = Request::builder()
            .version(Version::HTTP_10)
            .method(Method::GET)
            .header("upgrade", "websocket")
            .header("connection", "Upgrade")
            .header("sec-websocket-key", "6D69KGBOr4Re+Nj6zx9aQA==")
            .header("sec-websocket-version", "13")
            .body(Body::empty())
            .unwrap();

        let res = svc.oneshot(req).await.unwrap();

        assert_eq!(res.status(), StatusCode::OK);
    }

    // Compile-only check: a handler using the default failure callback can be routed.
    #[allow(dead_code)]
    fn default_on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    // Compile-only check: a closure can be installed as the failure callback.
    #[allow(dead_code)]
    fn on_failed_upgrade() {
        async fn handler(ws: WebSocketUpgrade) -> Response {
            ws.on_failed_upgrade(|_error: Error| println!("oops!"))
                .on_upgrade(|_| async {})
        }
        let _: Router = Router::new().route("/", any(handler));
    }

    /// End-to-end HTTP/1.1 handshake plus echo round-trip against a real server.
    #[crate::test]
    async fn integration_test() {
        let addr = spawn_service(echo_app());
        let uri = format!("ws://{addr}/echo").try_into().unwrap();
        let req = tungstenite::client::ClientRequestBuilder::new(uri)
            .with_sub_protocol(TEST_ECHO_APP_REQ_SUBPROTO);
        let (socket, response) = tokio_tungstenite::connect_async(req).await.unwrap();
        test_echo_app(socket, response.headers()).await;
    }

    /// End-to-end HTTP/2 extended CONNECT handshake plus echo round-trip.
    #[crate::test]
    #[cfg(feature = "http2")]
    async fn http2() {
        let addr = spawn_service(echo_app());
        let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (mut send_request, conn) =
            hyper::client::conn::http2::Builder::new(TokioExecutor::new())
                .handshake(io)
                .await
                .unwrap();

        // Yield a few times before checking extended CONNECT support.
        // NOTE(review): presumably this gives the client time to receive and process the
        // server's settings — confirm.
        for _ in 0..10 {
            tokio::task::yield_now().await;
        }
        assert!(conn.is_extended_connect_protocol_enabled());
        tokio::spawn(async {
            conn.await.unwrap();
        });

        let req = Request::builder()
            .method(Method::CONNECT)
            .extension(hyper::ext::Protocol::from_static("websocket"))
            .uri("/echo")
            .header("sec-websocket-version", "13")
            .header("sec-websocket-protocol", TEST_ECHO_APP_REQ_SUBPROTO)
            .header("Host", "server.example.com")
            .body(Body::empty())
            .unwrap();

        let mut response = send_request.send_request(req).await.unwrap();
        let status = response.status();
        if status != 200 {
            let body = response.into_body().collect().await.unwrap().to_bytes();
            let body = std::str::from_utf8(&body).unwrap();
            panic!("response status was {status}: {body}");
        }
        let upgraded = hyper::upgrade::on(&mut response).await.unwrap();
        let upgraded = TokioIo::new(upgraded);
        let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
        test_echo_app(socket, response.headers()).await;
    }

    /// Router with an `/echo` endpoint. It selects the `echo` subprotocol and echoes back
    /// text, binary and close messages.
    fn echo_app() -> Router {
        async fn handle_socket(mut socket: WebSocket) {
            assert_eq!(socket.protocol().unwrap(), "echo");
            while let Some(Ok(msg)) = socket.recv().await {
                match msg {
                    Message::Text(_) | Message::Binary(_) | Message::Close(_) => {
                        if socket.send(msg).await.is_err() {
                            break;
                        }
                    }
                    Message::Ping(_) | Message::Pong(_) => {
                        // Not echoed: `test_echo_app` still receives a pong for its ping, so
                        // the library answers pings on its own.
                    }
                }
            }
        }

        Router::new().route(
            "/echo",
            any(|ws: WebSocketUpgrade| {
                // The client offers "echo3, echo". "echo2" is skipped because the client did
                // not request it, so "echo" wins.
                let ws = ws.protocols(["echo2", "echo"]);
                assert_eq!(ws.selected_protocol().unwrap(), "echo");
                ready(ws.on_upgrade(handle_socket))
            }),
        )
    }

    /// Subprotocols offered by the test clients.
    const TEST_ECHO_APP_REQ_SUBPROTO: &str = "echo3, echo";

    /// Checks the negotiated subprotocol header, a text echo and a ping/pong exchange.
    async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(
        mut socket: WebSocketStream<S>,
        headers: &http::HeaderMap,
    ) {
        assert_eq!(headers[http::header::SEC_WEBSOCKET_PROTOCOL], "echo");

        let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
        socket.send(input.clone()).await.unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(input, output);

        socket
            .send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
            .await
            .unwrap();
        let output = socket.next().await.unwrap().unwrap();
        assert_eq!(
            output,
            tungstenite::Message::Pong(Bytes::from_static(b"ping"))
        );
    }
}