1mod builder;
117mod options;
118mod split;
119pub mod streaming;
120mod upgrade;
121
122use std::{
123 borrow::BorrowMut,
124 collections::VecDeque,
125 future::poll_fn,
126 io,
127 net::SocketAddr,
128 pin::{Pin, pin},
129 str::FromStr,
130 sync::Arc,
131 task::{Context, Poll, ready},
132 time::{Duration, Instant},
133};
134
135pub use builder::{HttpRequest, HttpRequestBuilder, WebSocketBuilder};
136use bytes::Bytes;
137use codec::Codec;
138use compression::{Compressor, Decompressor, WebSocketExtensions};
139pub use frame::{Frame, OpCode};
140use futures::{SinkExt, task::AtomicWaker};
141use http_body_util::Empty;
142use hyper::{Request, Response, StatusCode, body::Incoming, header, upgrade::Upgraded};
143use hyper_util::rt::TokioIo;
144pub use options::{CompressionLevel, DeflateOptions, Fragmentation, Options};
145pub use split::{ReadHalf, WriteHalf};
146use tokio::{
147 io::{AsyncRead, AsyncWrite},
148 net::TcpStream,
149};
150use tokio_rustls::{
151 TlsConnector,
152 rustls::{
153 self,
154 pki_types::{ServerName, TrustAnchor},
155 },
156};
157use tokio_util::codec::Framed;
158pub use upgrade::UpgradeFut;
159use url::Url;
160
161pub use crate::stream::MaybeTlsStream;
163use crate::{Result, WebSocketError, codec, compression, frame, streaming::Streaming};
164
/// Client-side WebSocket over plain TCP (`ws://`) or TLS (`wss://`).
pub type TcpWebSocket = WebSocket<MaybeTlsStream<TcpStream>>;

/// Server-side WebSocket over a connection upgraded by hyper.
pub type HttpWebSocket = WebSocket<HttpStream>;

#[cfg(feature = "axum")]
pub use upgrade::IncomingUpgrade;

/// Default `max_payload_read` (1 MiB), used when [`Options`] does not set one.
pub const MAX_PAYLOAD_READ: usize = 1024 * 1024;

/// Default `max_read_buffer` (2 MiB), used when [`Options`] sets neither
/// `max_read_buffer` nor `max_payload_read`. It bounds the total size of a
/// reassembled fragmented message.
pub const MAX_READ_BUFFER: usize = 2 * 1024 * 1024;

/// Empty-bodied HTTP response returned by the server-side upgrade
/// (built with `101 Switching Protocols`).
pub type HttpResponse = Response<Empty<Bytes>>;

/// Result of a server-side upgrade: the response to send back to the client,
/// plus an [`UpgradeFut`] wrapping hyper's pending upgrade.
pub type UpgradeResult = Result<(HttpResponse, UpgradeFut)>;
212
/// Connection parameters settled during the opening handshake.
///
/// These are consumed by `WebSocket::new` and `Streaming::new`.
#[derive(Debug, Default, Clone)]
pub(crate) struct Negotiation {
    /// Negotiated `permessage-deflate` parameters; `None` when compression was not negotiated.
    pub(crate) extensions: Option<WebSocketExtensions>,
    /// Local compression level taken from `Options::compression`.
    /// When `None`, no compressor is created.
    pub(crate) compression_level: Option<CompressionLevel>,
    /// Payload read limit handed to `Streaming`, defaulting to `MAX_PAYLOAD_READ`.
    /// NOTE(review): presumably a per-frame limit; confirm in `streaming`.
    pub(crate) max_payload_read: usize,
    /// Upper bound on the size of a reassembled fragmented message.
    pub(crate) max_read_buffer: usize,
    /// Whether complete `Text` messages are validated as UTF-8.
    pub(crate) utf8: bool,
    /// Outgoing fragment size and incoming fragment timeout, if configured.
    pub(crate) fragmentation: Option<options::Fragmentation>,
    /// Passed through from `Options`; consumed by `Streaming` (not visible here).
    pub(crate) max_backpressure_write_boundary: Option<usize>,
}
224
225impl Negotiation {
226 pub(crate) fn decompressor(&self, role: Role) -> Option<Decompressor> {
227 let config = self.extensions.as_ref()?;
228
229 tracing::debug!(
230 "Established decompressor for {role} with settings \
231 client_no_context_takeover={} server_no_context_takeover={} \
232 server_max_window_bits={:?} client_max_window_bits={:?}",
233 config.client_no_context_takeover,
234 config.client_no_context_takeover,
235 config.server_max_window_bits,
236 config.client_max_window_bits
237 );
238
239 Some(if role == Role::Server {
241 if config.client_no_context_takeover {
242 Decompressor::no_context_takeover()
243 } else {
244 #[cfg(feature = "zlib")]
245 if let Some(Some(window_bits)) = config.client_max_window_bits {
246 Decompressor::new_with_window_bits(window_bits.max(9))
247 } else {
248 Decompressor::new()
249 }
250 #[cfg(not(feature = "zlib"))]
251 Decompressor::new()
252 }
253 } else {
254 if config.server_no_context_takeover {
256 Decompressor::no_context_takeover()
257 } else {
258 #[cfg(feature = "zlib")]
259 if let Some(Some(window_bits)) = config.server_max_window_bits {
260 Decompressor::new_with_window_bits(window_bits)
261 } else {
262 Decompressor::new()
263 }
264 #[cfg(not(feature = "zlib"))]
265 Decompressor::new()
266 }
267 })
268 }
269
270 pub(crate) fn compressor(&self, role: Role) -> Option<Compressor> {
271 let config = self.extensions.as_ref()?;
272
273 tracing::debug!(
274 "Established compressor for {role} with settings \
275 client_no_context_takeover={} server_no_context_takeover={} \
276 server_max_window_bits={:?} client_max_window_bits={:?}",
277 config.client_no_context_takeover,
278 config.client_no_context_takeover,
279 config.server_max_window_bits,
280 config.client_max_window_bits
281 );
282
283 let level = self.compression_level?;
284
285 Some(if role == Role::Client {
287 if config.client_no_context_takeover {
288 Compressor::no_context_takeover(level)
289 } else {
290 #[cfg(feature = "zlib")]
291 if let Some(Some(window_bits)) = config.client_max_window_bits {
292 Compressor::new_with_window_bits(level, window_bits)
293 } else {
294 Compressor::new(level)
295 }
296 #[cfg(not(feature = "zlib"))]
297 Compressor::new(level)
298 }
299 } else {
300 if config.server_no_context_takeover {
302 Compressor::no_context_takeover(level)
303 } else {
304 #[cfg(feature = "zlib")]
305 if let Some(Some(window_bits)) = config.server_max_window_bits {
306 Compressor::new_with_window_bits(level, window_bits)
307 } else {
308 Compressor::new(level)
309 }
310 #[cfg(not(feature = "zlib"))]
311 Compressor::new(level)
312 }
313 })
314 }
315}
316
/// Which side of the connection a [`WebSocket`] plays.
///
/// The role selects which negotiated `permessage-deflate` parameters
/// (`client_*` or `server_*`) apply to each direction.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum Role {
    /// The accepting side of the connection.
    Server,
    /// The initiating side (used by the client handshake).
    Client,
}

impl std::fmt::Display for Role {
    /// Formats the role as a lowercase word: `server` or `client`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Server => "server",
            Self::Client => "client",
        })
    }
}
335
336#[derive(Clone, Copy)]
341enum ContextKind {
342 Read,
344 Write,
346}
347
348#[derive(Default)]
350struct WakeProxy {
351 read_waker: AtomicWaker,
353 write_waker: AtomicWaker,
355}
356
357impl futures::task::ArcWake for WakeProxy {
358 fn wake_by_ref(this: &Arc<Self>) {
359 this.read_waker.wake();
360 this.write_waker.wake();
361 }
362}
363
364impl WakeProxy {
365 #[inline]
366 fn set_waker(&self, kind: ContextKind, waker: &futures::task::Waker) {
367 match kind {
368 ContextKind::Read => {
369 self.read_waker.register(waker);
370 }
371 ContextKind::Write => {
372 self.write_waker.register(waker);
373 }
374 }
375 }
376
377 #[inline(always)]
378 fn with_context<F, R>(self: &Arc<Self>, f: F) -> R
379 where
380 F: FnOnce(&mut Context<'_>) -> R,
381 {
382 let waker = futures::task::waker_ref(self);
383 let mut cx = Context::from_waker(&waker);
384 f(&mut cx)
385 }
386}
387
388pub enum HttpStream {
390 Hyper(TokioIo<Upgraded>),
392}
393
394impl From<TokioIo<Upgraded>> for HttpStream {
395 fn from(value: TokioIo<Upgraded>) -> Self {
396 Self::Hyper(value)
397 }
398}
399
400impl AsyncRead for HttpStream {
401 fn poll_read(
402 self: Pin<&mut Self>,
403 cx: &mut Context<'_>,
404 buf: &mut tokio::io::ReadBuf<'_>,
405 ) -> Poll<io::Result<()>> {
406 match self.get_mut() {
407 Self::Hyper(stream) => pin!(stream).poll_read(cx, buf),
408 }
409 }
410}
411
412impl AsyncWrite for HttpStream {
413 fn poll_write(
414 self: Pin<&mut Self>,
415 cx: &mut Context<'_>,
416 buf: &[u8],
417 ) -> Poll<std::result::Result<usize, io::Error>> {
418 match self.get_mut() {
419 Self::Hyper(stream) => pin!(stream).poll_write(cx, buf),
420 }
421 }
422
423 fn poll_flush(
424 self: Pin<&mut Self>,
425 cx: &mut Context<'_>,
426 ) -> Poll<std::result::Result<(), io::Error>> {
427 match self.get_mut() {
428 Self::Hyper(stream) => pin!(stream).poll_flush(cx),
429 }
430 }
431
432 fn poll_shutdown(
433 self: Pin<&mut Self>,
434 cx: &mut Context<'_>,
435 ) -> Poll<std::result::Result<(), io::Error>> {
436 match self.get_mut() {
437 Self::Hyper(stream) => pin!(stream).poll_shutdown(cx),
438 }
439 }
440}
441
/// A fragmented message that is still being reassembled by `FragmentLayer`.
pub(super) struct FragmentationState {
    /// When the first fragment arrived; compared against the fragment timeout.
    started: Instant,
    /// Opcode of the first frame, applied to the reassembled frame.
    opcode: OpCode,
    /// Whether the first frame was marked compressed; copied onto the reassembled frame.
    is_compressed: bool,
    /// Total payload bytes buffered so far across all fragments.
    bytes_read: usize,
    /// Fragment payloads received so far, in arrival order.
    parts: VecDeque<Bytes>,
}
451
452struct FragmentLayer {
460 outgoing_fragments: VecDeque<Frame>,
462 incoming_fragment: Option<FragmentationState>,
464 fragment_size: Option<usize>,
466 max_read_buffer: usize,
468 fragment_timeout: Option<Duration>,
470}
471
472impl FragmentLayer {
473 fn new(
475 fragment_size: Option<usize>,
476 max_read_buffer: usize,
477 fragment_timeout: Option<Duration>,
478 ) -> Self {
479 Self {
480 outgoing_fragments: VecDeque::new(),
481 incoming_fragment: None,
482 fragment_size,
483 max_read_buffer,
484 fragment_timeout,
485 }
486 }
487
488 fn fragment_outgoing(&mut self, frame: Frame) {
492 if !frame.is_fin() && self.fragment_size.is_some() {
494 panic!(
495 "Fragment the frames yourself or use `fragment_size`, but not both. Use Streaming"
496 );
497 }
498
499 let max_fragment_size = self.fragment_size.unwrap_or(usize::MAX);
500 self.outgoing_fragments
501 .extend(frame.into_fragments(max_fragment_size));
502 }
503
504 #[inline(always)]
506 fn pop_outgoing_fragment(&mut self) -> Option<Frame> {
507 self.outgoing_fragments.pop_front()
508 }
509
510 #[inline(always)]
512 fn has_outgoing_fragments(&self) -> bool {
513 !self.outgoing_fragments.is_empty()
514 }
515
516 fn assemble_incoming(&mut self, mut frame: Frame) -> Result<Option<Frame>> {
523 use bytes::BufMut;
524
525 #[cfg(test)]
526 println!(
527 "<<Fragmentation<< OpCode={:?} fin={} len={}",
528 frame.opcode(),
529 frame.is_fin(),
530 frame.payload.len()
531 );
532
533 match frame.opcode {
534 OpCode::Text | OpCode::Binary => {
535 if self.incoming_fragment.is_some() {
537 return Err(WebSocketError::InvalidFragment);
538 }
539
540 if !frame.fin {
542 let fragmentation = FragmentationState {
543 started: Instant::now(),
544 opcode: frame.opcode,
545 is_compressed: frame.is_compressed,
546 bytes_read: frame.payload.len(),
547 parts: VecDeque::from([frame.payload]),
548 };
549 self.incoming_fragment = Some(fragmentation);
550
551 return Ok(None);
552 }
553
554 Ok(Some(frame))
556 }
557 OpCode::Continuation => {
558 let mut fragment = self
559 .incoming_fragment
560 .take()
561 .ok_or_else(|| WebSocketError::InvalidFragment)?;
562
563 fragment.bytes_read += frame.payload.len();
564
565 if fragment.bytes_read >= self.max_read_buffer {
567 return Err(WebSocketError::FrameTooLarge);
568 }
569
570 if let Some(timeout) = self.fragment_timeout
572 && fragment.started.elapsed() > timeout
573 {
574 return Err(WebSocketError::FragmentTimeout);
575 }
576
577 fragment.parts.push_back(frame.payload);
578
579 if frame.fin {
580 frame.opcode = fragment.opcode;
582 frame.is_compressed = fragment.is_compressed;
583 frame.payload = fragment
584 .parts
585 .into_iter()
586 .fold(
587 bytes::BytesMut::with_capacity(fragment.bytes_read),
588 |mut acc, b| {
589 acc.put(b);
590 acc
591 },
592 )
593 .freeze();
594
595 Ok(Some(frame))
596 } else {
597 self.incoming_fragment = Some(fragment);
598 Ok(None)
599 }
600 }
601 _ => {
602 Ok(Some(frame))
604 }
605 }
606 }
607}
608
/// A WebSocket connection over the transport `S`.
///
/// The connection is used through two trait impls:
/// - [`futures::Stream`], which yields complete [`Frame`]s;
/// - [`futures::Sink<Frame>`], which accepts frames to send.
///
/// Fragment reassembly, optional outgoing fragmentation and optional UTF-8
/// validation are layered on top of the lower-level [`Streaming`] type, which
/// [`WebSocket::into_streaming`] returns.
pub struct WebSocket<S> {
    /// Underlying frame reader/writer that this type wraps.
    streaming: Streaming<S>,
    /// Whether complete `Text` messages are validated as UTF-8 before being returned.
    check_utf8: bool,
    /// Outgoing fragmentation and incoming reassembly state.
    fragment_layer: FragmentLayer,
}
727
728impl WebSocket<MaybeTlsStream<TcpStream>> {
729 pub fn connect(url: Url) -> WebSocketBuilder {
750 WebSocketBuilder::new(url)
751 }
752
753 pub(crate) async fn connect_priv(
754 url: Url,
755 tcp_address: Option<SocketAddr>,
756 connector: Option<TlsConnector>,
757 options: Options,
758 builder: HttpRequestBuilder,
759 ) -> Result<TcpWebSocket> {
760 let host = url
761 .host()
762 .ok_or_else(|| WebSocketError::InvalidHttpScheme)?
763 .to_string();
764
765 let tcp_stream = if let Some(tcp_address) = tcp_address {
766 TcpStream::connect(tcp_address).await?
767 } else {
768 let port = url
769 .port_or_known_default()
770 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "URL has no port"))?;
771 TcpStream::connect(format!("{host}:{port}")).await?
772 };
773
774 let _ = tcp_stream.set_nodelay(options.no_delay);
775
776 let stream = match url.scheme() {
777 "ws" => MaybeTlsStream::Plain(tcp_stream),
778 "wss" => {
779 let connector = connector.unwrap_or_else(tls_connector);
780 let domain = ServerName::try_from(host)
781 .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
782
783 MaybeTlsStream::Tls(connector.connect(domain, tcp_stream).await?)
784 }
785 _ => return Err(WebSocketError::InvalidHttpScheme),
786 };
787
788 WebSocket::handshake_with_request(url, stream, options, builder).await
789 }
790}
791
792impl<S> WebSocket<S>
793where
794 S: AsyncRead + AsyncWrite + Send + Unpin + 'static,
795{
796 pub async fn handshake(url: Url, io: S, options: Options) -> Result<WebSocket<S>> {
843 Self::handshake_with_request(url, io, options, HttpRequest::builder()).await
844 }
845
846 pub async fn handshake_with_request(
898 url: Url,
899 io: S,
900 options: Options,
901 mut builder: HttpRequestBuilder,
902 ) -> Result<WebSocket<S>> {
903 if !builder
904 .headers_ref()
905 .map(|h| h.contains_key(header::HOST))
906 .unwrap_or(false)
907 {
908 let host = url
909 .host()
910 .ok_or(WebSocketError::InvalidHttpScheme)?
911 .to_string();
912
913 let is_port_defined = url.port().is_some();
914 let port = url
915 .port_or_known_default()
916 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "URL has no port"))?;
917 let host_header = if is_port_defined {
918 format!("{host}:{port}")
919 } else {
920 host
921 };
922
923 builder = builder.header(header::HOST, host_header.as_str());
924 }
925
926 let target_url = &url[url::Position::BeforePath..];
927
928 let mut req = builder
929 .method("GET")
930 .uri(target_url)
931 .header(header::UPGRADE, "websocket")
932 .header(header::CONNECTION, "upgrade")
933 .header(header::SEC_WEBSOCKET_KEY, generate_key())
934 .header(header::SEC_WEBSOCKET_VERSION, "13")
935 .body(Empty::<Bytes>::new())
936 .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e.to_string()))?;
937
938 if let Some(compression) = options.compression.as_ref() {
939 let extensions = WebSocketExtensions::from(compression);
940 let header_value =
941 extensions
942 .to_string()
943 .parse()
944 .map_err(|e: header::InvalidHeaderValue| {
945 io::Error::new(io::ErrorKind::InvalidInput, e.to_string())
946 })?;
947 req.headers_mut()
948 .insert(header::SEC_WEBSOCKET_EXTENSIONS, header_value);
949 }
950
951 let (mut sender, conn) = hyper::client::conn::http1::handshake(TokioIo::new(io)).await?;
952
953 #[cfg(not(feature = "smol"))]
954 tokio::spawn(async move {
955 if let Err(err) = conn.with_upgrades().await {
956 tracing::error!("upgrading connection: {:?}", err);
957 }
958 });
959
960 #[cfg(feature = "smol")]
961 smol::spawn(async move {
962 if let Err(err) = conn.with_upgrades().await {
963 tracing::error!("upgrading connection: {:?}", err);
964 }
965 })
966 .detach();
967
968 let mut response = sender.send_request(req).await?;
969 let negotiated = verify(&response, options)?;
970
971 let upgraded = hyper::upgrade::on(&mut response).await?;
972 let parts = upgraded.downcast::<TokioIo<S>>().unwrap();
973
974 let stream = parts.io.into_inner();
976 let read_buf = parts.read_buf;
977
978 Ok(WebSocket::new(Role::Client, stream, read_buf, negotiated))
979 }
980}
981
982impl WebSocket<HttpStream> {
983 pub fn upgrade<B>(request: impl BorrowMut<Request<B>>) -> UpgradeResult {
987 Self::upgrade_with_options(request, Options::default())
988 }
989
990 pub fn upgrade_with_options<B>(
992 mut request: impl BorrowMut<Request<B>>,
993 options: Options,
994 ) -> UpgradeResult {
995 let request = request.borrow_mut();
996
997 let key = request
998 .headers()
999 .get(header::SEC_WEBSOCKET_KEY)
1000 .ok_or(WebSocketError::MissingSecWebSocketKey)?;
1001
1002 if request
1003 .headers()
1004 .get(header::SEC_WEBSOCKET_VERSION)
1005 .map(|v| v.as_bytes())
1006 != Some(b"13")
1007 {
1008 return Err(WebSocketError::InvalidSecWebsocketVersion);
1009 }
1010
1011 let maybe_compression = request
1012 .headers()
1013 .get(header::SEC_WEBSOCKET_EXTENSIONS)
1014 .and_then(|h| h.to_str().ok())
1015 .map(WebSocketExtensions::from_str)
1016 .and_then(std::result::Result::ok);
1017
1018 let mut response = Response::builder()
1019 .status(hyper::StatusCode::SWITCHING_PROTOCOLS)
1020 .header(hyper::header::CONNECTION, "upgrade")
1021 .header(hyper::header::UPGRADE, "websocket")
1022 .header(
1023 header::SEC_WEBSOCKET_ACCEPT,
1024 upgrade::sec_websocket_protocol(key.as_bytes()),
1025 )
1026 .body(Empty::new())
1027 .map_err(|e| {
1028 io::Error::new(io::ErrorKind::InvalidInput, format!("build response: {e}"))
1029 })?;
1030
1031 let extensions = if let Some(client_compression) = maybe_compression {
1032 if let Some(server_compression) = options.compression.as_ref() {
1033 let offer = server_compression.merge(&client_compression);
1034
1035 let header_value =
1036 offer
1037 .to_string()
1038 .parse()
1039 .map_err(|e: header::InvalidHeaderValue| {
1040 io::Error::new(
1041 io::ErrorKind::InvalidInput,
1042 format!("extension header: {e}"),
1043 )
1044 })?;
1045 response
1046 .headers_mut()
1047 .insert(header::SEC_WEBSOCKET_EXTENSIONS, header_value);
1048
1049 Some(offer)
1050 } else {
1051 None
1052 }
1053 } else {
1054 None
1055 };
1056
1057 let max_read_buffer = options.max_read_buffer.unwrap_or(
1058 options
1059 .max_payload_read
1060 .map(|payload_read| payload_read * 2)
1061 .unwrap_or(MAX_READ_BUFFER),
1062 );
1063
1064 let stream = UpgradeFut {
1065 inner: hyper::upgrade::on(request),
1066 negotiation: Some(Negotiation {
1067 extensions,
1068 compression_level: options
1069 .compression
1070 .as_ref()
1071 .map(|compression| compression.level),
1072 max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
1073 max_read_buffer,
1074 utf8: options.check_utf8,
1075 fragmentation: options.fragmentation.clone(),
1076 max_backpressure_write_boundary: options.max_backpressure_write_boundary,
1077 }),
1078 };
1079
1080 Ok((response, stream))
1081 }
1082}
1083
1084impl<S> WebSocket<S>
1087where
1088 S: AsyncRead + AsyncWrite + Unpin,
1089{
1090 pub unsafe fn split_stream(self) -> (Framed<S, Codec>, ReadHalf, WriteHalf) {
1095 unsafe { self.streaming.split_stream() }
1097 }
1098
1099 pub fn into_streaming(self) -> Streaming<S> {
1131 self.streaming
1132 }
1133
1134 pub fn poll_next_frame(&mut self, cx: &mut Context<'_>) -> Poll<Result<Frame>> {
1136 loop {
1137 let frame = ready!(self.streaming.poll_next_frame(cx))?;
1138 match self.on_frame(frame)? {
1139 Some(ok) => break Poll::Ready(Ok(ok)),
1140 None => continue,
1141 }
1142 }
1143 }
1144
1145 pub async fn next_frame(&mut self) -> Result<Frame> {
1147 poll_fn(|cx| self.poll_next_frame(cx)).await
1148 }
1149
1150 pub(crate) fn new(role: Role, stream: S, read_buf: Bytes, opts: Negotiation) -> Self {
1155 Self {
1156 streaming: Streaming::new(role, stream, read_buf, &opts),
1157 check_utf8: opts.utf8,
1158 fragment_layer: FragmentLayer::new(
1159 opts.fragmentation.as_ref().and_then(|f| f.fragment_size),
1160 opts.max_read_buffer,
1161 opts.fragmentation.as_ref().and_then(|f| f.timeout),
1162 ),
1163 }
1164 }
1165
1166 fn on_frame(&mut self, frame: Frame) -> Result<Option<Frame>> {
1167 let frame = match self.fragment_layer.assemble_incoming(frame)? {
1168 Some(frame) => frame,
1169 None => return Ok(None), };
1171
1172 if frame.opcode == OpCode::Text && self.check_utf8 {
1173 #[cfg(not(feature = "simd"))]
1174 if std::str::from_utf8(&frame.payload).is_err() {
1175 return Err(WebSocketError::InvalidUTF8);
1176 }
1177 #[cfg(feature = "simd")]
1178 if simdutf8::basic::from_utf8(&frame.payload).is_err() {
1179 return Err(WebSocketError::InvalidUTF8);
1180 }
1181 }
1182
1183 Ok(Some(frame))
1184 }
1185}
1186
1187impl<S> futures::Stream for WebSocket<S>
1188where
1189 S: AsyncRead + AsyncWrite + Unpin,
1190{
1191 type Item = Frame;
1192
1193 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1194 let this = self.get_mut();
1195 match ready!(this.poll_next_frame(cx)) {
1196 Ok(ok) => Poll::Ready(Some(ok)),
1197 Err(_) => Poll::Ready(None),
1198 }
1199 }
1200}
1201
1202impl<S> futures::Sink<Frame> for WebSocket<S>
1203where
1204 S: AsyncRead + AsyncWrite + Unpin,
1205{
1206 type Error = WebSocketError;
1207
1208 fn poll_ready(
1209 self: Pin<&mut Self>,
1210 cx: &mut Context<'_>,
1211 ) -> Poll<std::result::Result<(), Self::Error>> {
1212 let this = self.get_mut();
1213 this.streaming.poll_ready_unpin(cx)
1214 }
1215
1216 fn start_send(self: Pin<&mut Self>, item: Frame) -> std::result::Result<(), Self::Error> {
1217 let this = self.get_mut();
1218 this.fragment_layer.fragment_outgoing(item);
1219 Ok(())
1220 }
1221
1222 fn poll_flush(
1223 self: Pin<&mut Self>,
1224 cx: &mut Context<'_>,
1225 ) -> Poll<std::result::Result<(), Self::Error>> {
1226 let this = self.get_mut();
1227
1228 while this.fragment_layer.has_outgoing_fragments() {
1230 ready!(this.streaming.poll_ready_unpin(cx))?;
1233 let fragment = this
1234 .fragment_layer
1235 .pop_outgoing_fragment()
1236 .expect("fragment");
1237 this.streaming.start_send_unpin(fragment)?;
1238 }
1239
1240 this.streaming.poll_flush_unpin(cx)
1242 }
1243
1244 fn poll_close(
1245 self: Pin<&mut Self>,
1246 cx: &mut Context<'_>,
1247 ) -> Poll<std::result::Result<(), Self::Error>> {
1248 let this = self.get_mut();
1249 this.streaming.poll_close_unpin(cx)
1250 }
1251}
1252
1253fn verify(response: &Response<Incoming>, options: Options) -> Result<Negotiation> {
1256 if response.status() != StatusCode::SWITCHING_PROTOCOLS {
1257 return Err(WebSocketError::InvalidStatusCode(
1258 response.status().as_u16(),
1259 ));
1260 }
1261
1262 let compression_level = options.compression.as_ref().map(|opts| opts.level);
1263 let headers = response.headers();
1264
1265 if !headers
1266 .get(header::UPGRADE)
1267 .and_then(|h| h.to_str().ok())
1268 .map(|h| h.eq_ignore_ascii_case("websocket"))
1269 .unwrap_or(false)
1270 {
1271 return Err(WebSocketError::InvalidUpgradeHeader);
1272 }
1273
1274 if !headers
1275 .get(header::CONNECTION)
1276 .and_then(|h| h.to_str().ok())
1277 .map(|h| h.eq_ignore_ascii_case("Upgrade"))
1278 .unwrap_or(false)
1279 {
1280 return Err(WebSocketError::InvalidConnectionHeader);
1281 }
1282
1283 let extensions = headers
1284 .get(header::SEC_WEBSOCKET_EXTENSIONS)
1285 .and_then(|h| h.to_str().ok())
1286 .map(WebSocketExtensions::from_str)
1287 .and_then(std::result::Result::ok);
1288
1289 let max_read_buffer = options.max_read_buffer.unwrap_or(
1290 options
1291 .max_payload_read
1292 .map(|payload_read| payload_read * 2)
1293 .unwrap_or(MAX_READ_BUFFER),
1294 );
1295
1296 Ok(Negotiation {
1297 extensions,
1298 compression_level,
1299 max_payload_read: options.max_payload_read.unwrap_or(MAX_PAYLOAD_READ),
1300 max_read_buffer,
1301 utf8: options.check_utf8,
1302 fragmentation: options.fragmentation.clone(),
1303 max_backpressure_write_boundary: options.max_backpressure_write_boundary,
1304 })
1305}
1306
1307fn generate_key() -> String {
1308 use base64::prelude::*;
1309 let input: [u8; 16] = rand::random();
1310 BASE64_STANDARD.encode(input)
1311}
1312
/// Builds the default TLS connector used for `wss://` URLs when no connector is supplied.
///
/// The connector:
/// - trusts the root certificates bundled in `webpki-roots`;
/// - advertises only `http/1.1` via ALPN, since the handshake uses hyper's HTTP/1 client;
/// - uses the process-wide default rustls `CryptoProvider` if one is installed.
///
/// # Panics
///
/// - If no default provider is installed and neither the `rustls-ring` nor
///   the `rustls-aws-lc-rs` feature is enabled.
/// - If the selected provider rejects `rustls::ALL_VERSIONS` (`expect("versions")`).
fn tls_connector() -> TlsConnector {
    // Copy each bundled trust anchor field by field into the `TrustAnchor` type
    // that rustls expects.
    // NOTE(review): presumably done to bridge differing `pki_types` versions
    // between `webpki-roots` and rustls; confirm.
    let mut root_cert_store = rustls::RootCertStore::empty();
    root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|ta| TrustAnchor {
        subject: ta.subject.clone(),
        subject_public_key_info: ta.subject_public_key_info.clone(),
        name_constraints: ta.name_constraints.clone(),
    }));

    // Prefer a provider the application has already installed as the process default.
    let maybe_provider = rustls::crypto::CryptoProvider::get_default().cloned();

    #[cfg(any(feature = "rustls-ring", feature = "rustls-aws-lc-rs"))]
    let provider = maybe_provider.unwrap_or_else(|| {
        // If both features are enabled, the second binding shadows the first,
        // so aws-lc-rs is used.
        #[cfg(feature = "rustls-ring")]
        let _provider = rustls::crypto::ring::default_provider();
        #[cfg(feature = "rustls-aws-lc-rs")]
        let _provider = rustls::crypto::aws_lc_rs::default_provider();

        Arc::new(_provider)
    });

    #[cfg(not(any(feature = "rustls-ring", feature = "rustls-aws-lc-rs")))]
    let provider = maybe_provider.expect(
        r#"No Rustls crypto provider was enabled for yawc to connect to a `wss://` endpoint!

Either:
  - provide a `connector` in the WebSocketBuilder options
  - enable one of the following features: `rustls-ring`, `rustls-aws-lc-rs`"#,
    );

    let mut config = rustls::ClientConfig::builder_with_provider(provider)
        .with_protocol_versions(rustls::ALL_VERSIONS)
        .expect("versions")
        .with_root_certificates(root_cert_store)
        .with_no_client_auth();
    config.alpn_protocols = vec!["http/1.1".into()];

    TlsConnector::from(Arc::new(config))
}
1352
1353#[cfg(test)]
1354mod tests {
1355 use std::{
1356 pin::Pin,
1357 task::{Context, Poll},
1358 };
1359
1360 use futures::SinkExt;
1361 use tokio::io::{AsyncRead, AsyncWrite, DuplexStream, ReadBuf};
1362
1363 use super::*;
1364 use crate::close::{self, CloseCode};
1365
1366 struct MockStream {
1368 inner: DuplexStream,
1369 }
1370
1371 impl MockStream {
1372 fn pair(buffer_size: usize) -> (Self, Self) {
1374 let (a, b) = tokio::io::duplex(buffer_size);
1375 (Self { inner: a }, Self { inner: b })
1376 }
1377 }
1378
1379 impl AsyncRead for MockStream {
1380 fn poll_read(
1381 mut self: Pin<&mut Self>,
1382 cx: &mut Context<'_>,
1383 buf: &mut ReadBuf<'_>,
1384 ) -> Poll<io::Result<()>> {
1385 Pin::new(&mut self.inner).poll_read(cx, buf)
1386 }
1387 }
1388
1389 impl AsyncWrite for MockStream {
1390 fn poll_write(
1391 mut self: Pin<&mut Self>,
1392 cx: &mut Context<'_>,
1393 buf: &[u8],
1394 ) -> Poll<io::Result<usize>> {
1395 Pin::new(&mut self.inner).poll_write(cx, buf)
1396 }
1397
1398 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1399 Pin::new(&mut self.inner).poll_flush(cx)
1400 }
1401
1402 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
1403 Pin::new(&mut self.inner).poll_shutdown(cx)
1404 }
1405 }
1406
1407 fn create_websocket_pair(buffer_size: usize) -> (WebSocket<MockStream>, WebSocket<MockStream>) {
1409 create_websocket_pair_with_config(buffer_size, None, None)
1410 }
1411
1412 fn create_websocket_pair_with_config(
1413 buffer_size: usize,
1414 fragment_size: Option<usize>,
1415 compression_level: Option<CompressionLevel>,
1416 ) -> (WebSocket<MockStream>, WebSocket<MockStream>) {
1417 let (client_stream, server_stream) = MockStream::pair(buffer_size);
1418
1419 let extensions = compression_level.map(|_level| WebSocketExtensions {
1420 server_max_window_bits: None,
1421 client_max_window_bits: None,
1422 server_no_context_takeover: false,
1423 client_no_context_takeover: false,
1424 });
1425
1426 let negotiation = Negotiation {
1427 extensions,
1428 compression_level,
1429 max_payload_read: MAX_PAYLOAD_READ,
1430 max_read_buffer: MAX_READ_BUFFER,
1431 utf8: false,
1432 fragmentation: fragment_size.map(|size| options::Fragmentation {
1433 timeout: None,
1434 fragment_size: Some(size),
1435 }),
1436 max_backpressure_write_boundary: None,
1437 };
1438
1439 let client_ws = WebSocket::new(
1440 Role::Client,
1441 client_stream,
1442 Bytes::new(),
1443 negotiation.clone(),
1444 );
1445
1446 let server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1447
1448 (client_ws, server_ws)
1449 }
1450
1451 #[tokio::test]
1452 async fn test_send_and_receive_text_frame() {
1453 let (mut client, mut server) = create_websocket_pair(1024);
1454
1455 let text = "Hello, WebSocket!";
1456 client
1457 .send(Frame::text(text))
1458 .await
1459 .expect("Failed to send text frame");
1460
1461 let frame = server.next_frame().await.expect("Failed to receive frame");
1462
1463 assert_eq!(frame.opcode(), OpCode::Text);
1464 assert_eq!(frame.payload(), text.as_bytes());
1465 assert!(frame.is_fin());
1466 }
1467
1468 #[tokio::test]
1469 async fn test_send_and_receive_binary_frame() {
1470 let (mut client, mut server) = create_websocket_pair(1024);
1471
1472 let data = vec![1u8, 2, 3, 4, 5];
1473 client
1474 .send(Frame::binary(data.clone()))
1475 .await
1476 .expect("Failed to send binary frame");
1477
1478 let frame = server.next_frame().await.expect("Failed to receive frame");
1479
1480 assert_eq!(frame.opcode(), OpCode::Binary);
1481 assert_eq!(frame.payload(), &data[..]);
1482 assert!(frame.is_fin());
1483 }
1484
1485 #[tokio::test]
1486 async fn test_bidirectional_communication() {
1487 let (mut client, mut server) = create_websocket_pair(2048);
1488
1489 client
1490 .send(Frame::text("Client message"))
1491 .await
1492 .expect("Failed to send from client");
1493
1494 let frame = server
1495 .next_frame()
1496 .await
1497 .expect("Failed to receive at server");
1498 assert_eq!(frame.payload(), b"Client message" as &[u8]);
1499
1500 server
1501 .send(Frame::text("Server response"))
1502 .await
1503 .expect("Failed to send from server");
1504
1505 let frame = client
1506 .next_frame()
1507 .await
1508 .expect("Failed to receive at client");
1509 assert_eq!(frame.payload(), b"Server response" as &[u8]);
1510 }
1511
1512 #[tokio::test]
1513 async fn test_ping_pong() {
1514 let (mut client, mut server) = create_websocket_pair(1024);
1515
1516 client
1521 .send(Frame::pong("pong_data"))
1522 .await
1523 .expect("Failed to send pong");
1524
1525 let frame = server.next_frame().await.expect("Failed to receive pong");
1526 assert_eq!(frame.opcode(), OpCode::Pong);
1527 assert_eq!(frame.payload(), b"pong_data" as &[u8]);
1528 }
1529
1530 #[tokio::test]
1531 async fn test_close_frame() {
1532 let (mut client, mut server) = create_websocket_pair(1024);
1533
1534 client
1535 .send(Frame::close(CloseCode::Normal, b"Goodbye"))
1536 .await
1537 .expect("Failed to send close frame");
1538
1539 let frame = server
1540 .next_frame()
1541 .await
1542 .expect("Failed to receive close frame");
1543
1544 assert_eq!(frame.opcode(), OpCode::Close);
1545 assert_eq!(frame.close_code(), Some(close::CloseCode::Normal));
1546 assert_eq!(
1547 frame.close_reason().expect("Invalid close reason"),
1548 Some("Goodbye")
1549 );
1550 }
1551
1552 #[tokio::test]
1553 async fn test_large_message() {
1554 let (mut client, mut server) = create_websocket_pair(65536);
1555
1556 let large_data = vec![42u8; 10240];
1557 client
1558 .send(Frame::binary(large_data.clone()))
1559 .await
1560 .expect("Failed to send large message");
1561
1562 let frame = server
1563 .next_frame()
1564 .await
1565 .expect("Failed to receive large message");
1566
1567 assert_eq!(frame.opcode(), OpCode::Binary);
1568 assert_eq!(frame.payload().len(), 10240);
1569 assert_eq!(frame.payload(), &large_data[..]);
1570 }
1571
1572 #[tokio::test]
1573 async fn test_multiple_messages() {
1574 let (mut client, mut server) = create_websocket_pair(4096);
1575
1576 for i in 0..10 {
1577 let msg = format!("Message {}", i);
1578 client
1579 .send(Frame::text(msg.clone()))
1580 .await
1581 .expect("Failed to send message");
1582
1583 let frame = server
1584 .next_frame()
1585 .await
1586 .expect("Failed to receive message");
1587 assert_eq!(frame.payload(), msg.as_bytes());
1588 }
1589 }
1590
1591 #[tokio::test]
1592 async fn test_empty_payload() {
1593 let (mut client, mut server) = create_websocket_pair(1024);
1594
1595 client
1596 .send(Frame::text(Bytes::new()))
1597 .await
1598 .expect("Failed to send empty frame");
1599
1600 let frame = server
1601 .next_frame()
1602 .await
1603 .expect("Failed to receive empty frame");
1604
1605 assert_eq!(frame.opcode(), OpCode::Text);
1606 assert_eq!(frame.payload().len(), 0);
1607 }
1608
1609 #[tokio::test]
1610 async fn test_fragmented_message() {
1611 let (mut client, mut server) = create_websocket_pair(2048);
1612
1613 let mut frame1 = Frame::text("Hello, ");
1614 frame1.set_fin(false);
1615 client
1616 .send(frame1)
1617 .await
1618 .expect("Failed to send first fragment");
1619
1620 let frame2 = Frame::continuation("World!");
1621 client
1622 .send(frame2)
1623 .await
1624 .expect("Failed to send final fragment");
1625
1626 let received = server
1629 .next_frame()
1630 .await
1631 .expect("Failed to receive message");
1632 assert_eq!(received.opcode(), OpCode::Text);
1633 assert!(received.is_fin());
1634 assert_eq!(received.payload(), b"Hello, World!" as &[u8]);
1635 }
1636
1637 #[tokio::test]
1638 async fn test_concurrent_send_receive() {
1639 let (mut client, mut server) = create_websocket_pair(4096);
1640
1641 let client_task = tokio::spawn(async move {
1642 for i in 0..5 {
1643 client
1644 .send(Frame::text(format!("Client {}", i)))
1645 .await
1646 .expect("Failed to send from client");
1647
1648 let frame = client
1649 .next_frame()
1650 .await
1651 .expect("Failed to receive at client");
1652 assert_eq!(frame.payload(), format!("Server {}", i).as_bytes());
1653 }
1654 client
1655 });
1656
1657 let server_task = tokio::spawn(async move {
1658 for i in 0..5 {
1659 let frame = server
1660 .next_frame()
1661 .await
1662 .expect("Failed to receive at server");
1663 assert_eq!(frame.payload(), format!("Client {}", i).as_bytes());
1664
1665 server
1666 .send(Frame::text(format!("Server {}", i)))
1667 .await
1668 .expect("Failed to send from server");
1669 }
1670 server
1671 });
1672
1673 client_task.await.expect("Client task failed");
1674 server_task.await.expect("Server task failed");
1675 }
1676
1677 #[tokio::test]
1678 async fn test_utf8_validation() {
1679 let (mut client, mut server) = create_websocket_pair(1024);
1680
1681 let valid_utf8 = "Hello, 世界! 🌍";
1682 client
1683 .send(Frame::text(valid_utf8))
1684 .await
1685 .expect("Failed to send UTF-8 text");
1686
1687 let frame = server
1688 .next_frame()
1689 .await
1690 .expect("Failed to receive UTF-8 text");
1691 assert_eq!(frame.opcode(), OpCode::Text);
1692 assert!(frame.is_utf8());
1693 assert_eq!(std::str::from_utf8(frame.payload()).unwrap(), valid_utf8);
1694 }
1695
1696 #[tokio::test]
1697 async fn test_stream_trait_implementation() {
1698 use futures::StreamExt;
1699
1700 let (mut client, mut server) = create_websocket_pair(1024);
1701
1702 tokio::spawn(async move {
1703 for i in 0..3 {
1704 client
1705 .send(Frame::text(format!("Message {}", i)))
1706 .await
1707 .expect("Failed to send message");
1708 }
1709 });
1710
1711 let mut count = 0;
1712 while let Some(frame) = server.next().await {
1713 assert_eq!(frame.opcode(), OpCode::Text);
1714 count += 1;
1715 if count == 3 {
1716 break;
1717 }
1718 }
1719 assert_eq!(count, 3);
1720 }
1721
1722 #[tokio::test]
1723 async fn test_sink_trait_implementation() {
1724 use futures::SinkExt;
1725
1726 let (mut client, mut server) = create_websocket_pair(1024);
1727
1728 client
1729 .send(Frame::text("Sink message"))
1730 .await
1731 .expect("Failed to send via Sink");
1732
1733 client.flush().await.expect("Failed to flush");
1734
1735 let frame = server
1736 .next_frame()
1737 .await
1738 .expect("Failed to receive message");
1739 assert_eq!(frame.payload(), b"Sink message" as &[u8]);
1740 }
1741
1742 #[tokio::test]
1743 async fn test_rapid_small_messages() {
1744 let (mut client, mut server) = create_websocket_pair(8192);
1745
1746 let count = 100;
1747
1748 let sender = tokio::spawn(async move {
1749 for i in 0..count {
1750 client
1751 .send(Frame::text(format!("{}", i)))
1752 .await
1753 .expect("Failed to send");
1754 }
1755 client
1756 });
1757
1758 for i in 0..count {
1759 let frame = server.next_frame().await.expect("Failed to receive");
1760 assert_eq!(frame.payload(), format!("{}", i).as_bytes());
1761 }
1762
1763 sender.await.expect("Sender task failed");
1764 }
1765
1766 #[tokio::test]
1767 async fn test_interleaved_control_and_data_frames() {
1768 let (mut client, mut server) = create_websocket_pair(2048);
1769
1770 client
1771 .send(Frame::text("Data 1"))
1772 .await
1773 .expect("Failed to send");
1774
1775 client
1778 .send(Frame::pong("pong"))
1779 .await
1780 .expect("Failed to send pong");
1781
1782 client
1783 .send(Frame::binary(vec![1, 2, 3]))
1784 .await
1785 .expect("Failed to send");
1786
1787 let f1 = server.next_frame().await.expect("Failed to receive");
1788 assert_eq!(f1.opcode(), OpCode::Text);
1789 assert_eq!(f1.payload(), b"Data 1" as &[u8]);
1790
1791 let f2 = server.next_frame().await.expect("Failed to receive");
1792 assert_eq!(f2.opcode(), OpCode::Pong);
1793
1794 let f3 = server.next_frame().await.expect("Failed to receive");
1795 assert_eq!(f3.opcode(), OpCode::Binary);
1796 assert_eq!(f3.payload(), &[1u8, 2, 3] as &[u8]);
1797 }
1798
1799 #[tokio::test]
1800 async fn test_client_sends_masked_frames() {
1801 let (mut client, mut _server) = create_websocket_pair(1024);
1802
1803 let frame = Frame::text("test");
1805 client.send(frame).await.expect("Failed to send");
1806
1807 }
1811
1812 #[tokio::test]
1813 async fn test_server_sends_unmasked_frames() {
1814 let (mut _client, mut server) = create_websocket_pair(1024);
1815
1816 let frame = Frame::text("test");
1818 server.send(frame).await.expect("Failed to send");
1819
1820 }
1822
1823 #[tokio::test]
1824 async fn test_close_code_variants() {
1825 let (mut client, mut server) = create_websocket_pair(1024);
1826
1827 client
1828 .send(Frame::close(close::CloseCode::Away, b""))
1829 .await
1830 .expect("Failed to send close");
1831
1832 let frame = server.next_frame().await.expect("Failed to receive");
1833 assert_eq!(frame.close_code(), Some(close::CloseCode::Away));
1834 }
1835
1836 #[tokio::test]
1837 async fn test_multiple_fragments() {
1838 let (mut client, mut server) = create_websocket_pair(4096);
1839
1840 for i in 0..5 {
1842 let is_last = i == 4;
1843 let opcode = if i == 0 {
1844 OpCode::Text
1845 } else {
1846 OpCode::Continuation
1847 };
1848
1849 let mut frame = Frame::from((opcode, format!("part{}", i)));
1850 frame.set_fin(is_last);
1851 client.send(frame).await.expect("Failed to send fragment");
1852 }
1853
1854 let frame = server.next_frame().await.expect("Failed to receive");
1857 assert_eq!(frame.opcode(), OpCode::Text);
1858 assert!(frame.is_fin());
1859
1860 let expected = "part0part1part2part3part4";
1862 assert_eq!(frame.payload(), expected.as_bytes());
1863 }
1864
1865 #[tokio::test]
1866 async fn test_automatic_fragmentation_large_messages() {
1867 let (client_stream, server_stream) = MockStream::pair(8192);
1869
1870 let negotiation = Negotiation {
1871 extensions: None,
1872 compression_level: None,
1873 max_payload_read: MAX_PAYLOAD_READ,
1874 max_read_buffer: MAX_READ_BUFFER,
1875 utf8: false,
1876 fragmentation: Some(options::Fragmentation {
1877 timeout: None,
1878 fragment_size: Some(100),
1879 }),
1880 max_backpressure_write_boundary: None,
1881 };
1882
1883 let mut client_ws = WebSocket::new(
1884 Role::Client,
1885 client_stream,
1886 Bytes::new(),
1887 negotiation.clone(),
1888 );
1889
1890 let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1891
1892 let large_payload = vec![b'A'; 300];
1894 client_ws
1895 .send(Frame::binary(large_payload.clone()))
1896 .await
1897 .unwrap();
1898
1899 let received = server_ws.next_frame().await.unwrap();
1901 assert_eq!(received.opcode(), OpCode::Binary);
1902 assert_eq!(received.payload(), large_payload.as_slice());
1903 }
1904
1905 #[tokio::test]
1906 async fn test_automatic_fragmentation_small_messages() {
1907 let (client_stream, server_stream) = MockStream::pair(8192);
1909
1910 let negotiation = Negotiation {
1911 extensions: None,
1912 compression_level: None,
1913 max_payload_read: MAX_PAYLOAD_READ,
1914 max_read_buffer: MAX_READ_BUFFER,
1915 utf8: false,
1916 fragmentation: Some(options::Fragmentation {
1917 timeout: None,
1918 fragment_size: Some(100),
1919 }),
1920 max_backpressure_write_boundary: None,
1921 };
1922
1923 let mut client_ws = WebSocket::new(
1924 Role::Client,
1925 client_stream,
1926 Bytes::new(),
1927 negotiation.clone(),
1928 );
1929
1930 let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1931
1932 let small_payload = vec![b'B'; 50];
1934 client_ws
1935 .send(Frame::text(small_payload.clone()))
1936 .await
1937 .unwrap();
1938
1939 let received = server_ws.next_frame().await.unwrap();
1941 assert_eq!(received.opcode(), OpCode::Text);
1942 assert_eq!(received.payload(), small_payload.as_slice());
1943 }
1944
1945 #[tokio::test]
1946 async fn test_no_fragmentation_when_not_configured() {
1947 let (client_stream, server_stream) = MockStream::pair(8192);
1949
1950 let negotiation = Negotiation {
1951 extensions: None,
1952 compression_level: None,
1953 max_payload_read: MAX_PAYLOAD_READ,
1954 max_read_buffer: MAX_READ_BUFFER,
1955 utf8: false,
1956 fragmentation: None,
1957 max_backpressure_write_boundary: None,
1958 };
1959
1960 let mut client_ws = WebSocket::new(
1961 Role::Client,
1962 client_stream,
1963 Bytes::new(),
1964 negotiation.clone(),
1965 );
1966
1967 let mut server_ws = WebSocket::new(Role::Server, server_stream, Bytes::new(), negotiation);
1968
1969 let large_payload = vec![b'C'; 1000];
1971 client_ws
1972 .send(Frame::binary(large_payload.clone()))
1973 .await
1974 .unwrap();
1975
1976 let received = server_ws.next_frame().await.unwrap();
1978 assert_eq!(received.opcode(), OpCode::Binary);
1979 assert_eq!(received.payload(), large_payload.as_slice());
1980 }
1981
1982 #[tokio::test]
1983 async fn test_interleave_control_frames_with_continuation_frames() {
1984 let (mut client, mut server) = create_websocket_pair(4096);
1992
1993 let mut fragment1 = Frame::text("Hello, ");
1995 fragment1.set_fin(false);
1996 client
1997 .send(fragment1)
1998 .await
1999 .expect("Failed to send first fragment");
2000
2001 client
2003 .send(Frame::ping("ping during fragmentation"))
2004 .await
2005 .expect("Failed to send ping");
2006
2007 let mut fragment2 = Frame::continuation("World");
2009 fragment2.set_fin(false);
2010 client
2011 .send(fragment2)
2012 .await
2013 .expect("Failed to send second fragment");
2014
2015 client
2017 .send(Frame::pong("pong during fragmentation"))
2018 .await
2019 .expect("Failed to send pong");
2020
2021 let fragment3 = Frame::continuation("!");
2023 client
2024 .send(fragment3)
2025 .await
2026 .expect("Failed to send final fragment");
2027
2028 let ping_frame = server
2030 .next_frame()
2031 .await
2032 .expect("Failed to receive ping frame");
2033 assert_eq!(ping_frame.opcode(), OpCode::Ping);
2034 assert_eq!(ping_frame.payload(), b"ping during fragmentation" as &[u8]);
2035
2036 let pong_frame = server
2038 .next_frame()
2039 .await
2040 .expect("Failed to receive pong frame");
2041 assert_eq!(pong_frame.opcode(), OpCode::Pong);
2042 assert_eq!(pong_frame.payload(), b"pong during fragmentation" as &[u8]);
2043
2044 let message_frame = server
2046 .next_frame()
2047 .await
2048 .expect("Failed to receive reassembled message");
2049 assert_eq!(message_frame.opcode(), OpCode::Text);
2050 assert!(message_frame.is_fin());
2051 assert_eq!(message_frame.payload(), b"Hello, World!" as &[u8]);
2052 }
2053
2054 #[tokio::test]
2055 async fn test_large_compressed_fragmented_payload() {
2056 const FRAGMENT_SIZE: usize = 65536;
2064 const PAYLOAD_SIZE: usize = 1024 * 1024; use flate2::Compression;
2067
2068 let (mut client, mut server) = create_websocket_pair_with_config(
2069 256 * 1024, None, Some(Compression::best()),
2072 );
2073
2074 let payload: Vec<u8> = (0..PAYLOAD_SIZE).map(|i| (i % 256) as u8).collect();
2076
2077 let total_fragments = PAYLOAD_SIZE.div_ceil(FRAGMENT_SIZE);
2079 println!(
2080 "Sending {} bytes in {} fragments of {} bytes each",
2081 PAYLOAD_SIZE, total_fragments, FRAGMENT_SIZE
2082 );
2083
2084 let server_task = tokio::spawn(async move {
2086 server
2087 .next_frame()
2088 .await
2089 .expect("Failed to receive large payload")
2090 });
2091
2092 let mut offset = 0;
2094 let mut fragment_num = 0;
2095
2096 while offset < PAYLOAD_SIZE {
2097 let end = std::cmp::min(offset + FRAGMENT_SIZE, PAYLOAD_SIZE);
2098 let chunk = payload[offset..end].to_vec();
2099 let is_final = end == PAYLOAD_SIZE;
2100
2101 let mut frame = if fragment_num == 0 {
2102 Frame::binary(chunk)
2104 } else {
2105 Frame::continuation(chunk)
2107 };
2108
2109 frame.set_fin(is_final);
2111
2112 println!(
2113 "Sending fragment {}/{}: {} bytes, OpCode={:?} FIN={}",
2114 fragment_num + 1,
2115 total_fragments,
2116 frame.payload().len(),
2117 frame.opcode(),
2118 is_final
2119 );
2120
2121 client
2122 .send(frame)
2123 .await
2124 .unwrap_or_else(|_| panic!("Failed to send fragment {}", fragment_num + 1));
2125
2126 offset = end;
2127 fragment_num += 1;
2128 }
2129
2130 let received_frame = server_task.await.expect("Server task failed");
2132
2133 assert_eq!(received_frame.opcode(), OpCode::Binary);
2135 assert!(received_frame.is_fin());
2136 assert_eq!(received_frame.payload().len(), PAYLOAD_SIZE);
2137 assert_eq!(received_frame.payload().as_ref(), &payload[..]);
2138
2139 println!(
2140 "Successfully sent {} manual fragments, compressed, decompressed, and reassembled {} bytes",
2141 total_fragments, PAYLOAD_SIZE
2142 );
2143 }
2144
2145 #[tokio::test]
2146 async fn test_compressed_fragmented_with_interleaved_control() {
2147 const FRAGMENT_SIZE: usize = 65536;
2154
2155 use flate2::Compression;
2156
2157 let (mut client, mut server) = create_websocket_pair_with_config(
2158 128 * 1024,
2159 Some(FRAGMENT_SIZE),
2160 Some(Compression::best()),
2161 );
2162
2163 let payload = "This is a test payload that should compress well. ".repeat(5000);
2165 let original_payload = payload.clone();
2166 let payload_bytes = payload.as_bytes().to_vec();
2167
2168 tokio::spawn(async move {
2170 client
2171 .send(Frame::binary(payload_bytes))
2172 .await
2173 .expect("Failed to send payload");
2174
2175 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
2177 client
2178 .send(Frame::ping("test"))
2179 .await
2180 .expect("Failed to send ping");
2181 });
2182
2183 let mut received_message = None;
2185 let mut received_ping = false;
2186
2187 for _ in 0..2 {
2188 let frame = server.next_frame().await.expect("Failed to receive frame");
2189
2190 match frame.opcode() {
2191 OpCode::Binary => {
2192 assert!(frame.is_fin());
2193 received_message = Some(frame.payload().to_vec());
2194 }
2195 OpCode::Ping => {
2196 received_ping = true;
2197 }
2198 _ => panic!("Unexpected frame type: {:?}", frame.opcode()),
2199 }
2200 }
2201
2202 assert!(received_message.is_some(), "Message not received");
2203 assert!(received_ping, "Ping not received");
2204
2205 let received = String::from_utf8(received_message.unwrap())
2206 .expect("Invalid UTF-8 in received payload");
2207
2208 assert_eq!(
2209 received, original_payload,
2210 "Compressed fragmented payload mismatch"
2211 );
2212 }
2213}