1#![cfg_attr(docsrs, feature(doc_cfg))]
157
158mod close;
159mod error;
160mod fragment;
161mod frame;
162#[cfg(feature = "upgrade")]
164#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
165pub mod handshake;
166mod mask;
167#[cfg(feature = "upgrade")]
169#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
170pub mod upgrade;
171
172use bytes::Buf;
173
174use bytes::BytesMut;
175#[cfg(feature = "unstable-split")]
176use std::future::Future;
177
178use tokio::io::AsyncRead;
179#[cfg(feature = "unstable-split")]
180use std::future::Future;
181
182use tokio::io::AsyncReadExt;
183use tokio::io::AsyncWrite;
184use tokio::io::AsyncWriteExt;
185pub use crate::close::CloseCode;
186pub use crate::error::WebSocketError;
187pub use crate::fragment::FragmentCollector;
188#[cfg(feature = "unstable-split")]
189pub use crate::fragment::FragmentCollectorRead;
190pub use crate::frame::Frame;
191pub use crate::frame::OpCode;
192pub use crate::frame::Payload;
193pub use crate::mask::unmask;
194
/// Which end of the WebSocket connection this endpoint is.
///
/// The role controls automatic masking: with masking enabled, a `Client`
/// masks every frame it sends and a `Server` unmasks every frame it
/// receives (RFC 6455 §5.3).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Role {
  /// The side that accepted the connection.
  Server,
  /// The side that initiated the connection.
  Client,
}
200
201pub(crate) struct WriteHalf {
202 role: Role,
203 closed: bool,
204 vectored: bool,
205 auto_apply_mask: bool,
206 writev_threshold: usize,
207 write_buffer: Vec<u8>,
208}
209
210pub(crate) struct ReadHalf {
211 role: Role,
212 auto_apply_mask: bool,
213 auto_close: bool,
214 auto_pong: bool,
215 writev_threshold: usize,
216 max_message_size: usize,
217 buffer: BytesMut,
218}
219
220#[cfg(feature = "unstable-split")]
221pub struct WebSocketRead<S> {
222 stream: S,
223 read_half: ReadHalf,
224}
225
226#[cfg(feature = "unstable-split")]
227pub struct WebSocketWrite<S> {
228 stream: S,
229 write_half: WriteHalf,
230}
231
232#[cfg(feature = "unstable-split")]
233pub fn after_handshake_split<R, W>(
235 read: R,
236 write: W,
237 role: Role,
238) -> (WebSocketRead<R>, WebSocketWrite<W>)
239where
240 R: AsyncRead + Unpin,
241 W: AsyncWrite + Unpin,
242{
243 (
244 WebSocketRead {
245 stream: read,
246 read_half: ReadHalf::after_handshake(role),
247 },
248 WebSocketWrite {
249 stream: write,
250 write_half: WriteHalf::after_handshake(role),
251 },
252 )
253}
254
255#[cfg(feature = "unstable-split")]
256impl<'f, S> WebSocketRead<S> {
257 #[inline]
259 pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
260 (self.stream, self.read_half)
261 }
262
263 pub fn set_writev_threshold(&mut self, threshold: usize) {
264 self.read_half.writev_threshold = threshold;
265 }
266
267 pub fn set_auto_close(&mut self, auto_close: bool) {
271 self.read_half.auto_close = auto_close;
272 }
273
274 pub fn set_auto_pong(&mut self, auto_pong: bool) {
278 self.read_half.auto_pong = auto_pong;
279 }
280
281 pub fn set_max_message_size(&mut self, max_message_size: usize) {
285 self.read_half.max_message_size = max_message_size;
286 }
287
288 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
292 self.read_half.auto_apply_mask = auto_apply_mask;
293 }
294
295 pub async fn read_frame<R, E>(
297 &mut self,
298 send_fn: &mut impl FnMut(Frame<'f>) -> R,
299 ) -> Result<Frame, WebSocketError>
300 where
301 S: AsyncRead + Unpin,
302 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
303 R: Future<Output = Result<(), E>>,
304 {
305 loop {
306 let (res, obligated_send) =
307 self.read_half.read_frame_inner(&mut self.stream).await;
308 if let Some(frame) = obligated_send {
309 let res = send_fn(frame).await;
310 res.map_err(|e| WebSocketError::SendError(e.into()))?;
311 }
312 if let Some(frame) = res? {
313 break Ok(frame);
314 }
315 }
316 }
317}
318
319#[cfg(feature = "unstable-split")]
320impl<'f, S> WebSocketWrite<S> {
321 pub fn set_writev(&mut self, vectored: bool) {
325 self.write_half.vectored = vectored;
326 }
327
328 pub fn set_writev_threshold(&mut self, threshold: usize) {
329 self.write_half.writev_threshold = threshold;
330 }
331
332 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
336 self.write_half.auto_apply_mask = auto_apply_mask;
337 }
338
339 pub fn is_closed(&self) -> bool {
340 self.write_half.closed
341 }
342
343 pub async fn write_frame(
344 &mut self,
345 frame: Frame<'f>,
346 ) -> Result<(), WebSocketError>
347 where
348 S: AsyncWrite + Unpin,
349 {
350 self.write_half.write_frame(&mut self.stream, frame).await
351 }
352
353 pub async fn flush(&mut self) -> Result<(), WebSocketError>
354 where
355 S: AsyncWrite + Unpin,
356 {
357 flush(&mut self.stream).await
358 }
359}
360
361#[inline]
362async fn flush<S>(stream: &mut S) -> Result<(), WebSocketError>
363where
364 S: AsyncWrite + Unpin,
365{
366 stream.flush().await.map_err(WebSocketError::IoError)
367}
368
369#[cfg(feature = "unstable-split")]
370pub struct WebSocketRead<S> {
371 stream: S,
372 read_half: ReadHalf,
373}
374
375#[cfg(feature = "unstable-split")]
376pub struct WebSocketWrite<S> {
377 stream: S,
378 write_half: WriteHalf,
379}
380
381#[cfg(feature = "unstable-split")]
382pub fn after_handshake_split<R, W>(
384 read: R,
385 write: W,
386 role: Role,
387) -> (WebSocketRead<R>, WebSocketWrite<W>)
388where
389 R: AsyncRead + Unpin,
390 W: AsyncWrite + Unpin,
391{
392 (
393 WebSocketRead {
394 stream: read,
395 read_half: ReadHalf::after_handshake(role),
396 },
397 WebSocketWrite {
398 stream: write,
399 write_half: WriteHalf::after_handshake(role),
400 },
401 )
402}
403
404#[cfg(feature = "unstable-split")]
405impl<'f, S> WebSocketRead<S> {
406 #[inline]
408 pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
409 (self.stream, self.read_half)
410 }
411
412 pub fn set_writev_threshold(&mut self, threshold: usize) {
413 self.read_half.writev_threshold = threshold;
414 }
415
416 pub fn set_auto_close(&mut self, auto_close: bool) {
420 self.read_half.auto_close = auto_close;
421 }
422
423 pub fn set_auto_pong(&mut self, auto_pong: bool) {
427 self.read_half.auto_pong = auto_pong;
428 }
429
430 pub fn set_max_message_size(&mut self, max_message_size: usize) {
434 self.read_half.max_message_size = max_message_size;
435 }
436
437 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
441 self.read_half.auto_apply_mask = auto_apply_mask;
442 }
443
444 pub async fn read_frame<R, E>(
446 &mut self,
447 send_fn: &mut impl FnMut(Frame<'f>) -> R,
448 ) -> Result<Frame, WebSocketError>
449 where
450 S: AsyncRead + Unpin,
451 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
452 R: Future<Output = Result<(), E>>,
453 {
454 loop {
455 let (res, obligated_send) =
456 self.read_half.read_frame_inner(&mut self.stream).await;
457 if let Some(frame) = obligated_send {
458 let res = send_fn(frame).await;
459 res.map_err(|e| WebSocketError::SendError(e.into()))?;
460 }
461 if let Some(frame) = res? {
462 break Ok(frame);
463 }
464 }
465 }
466}
467
468#[cfg(feature = "unstable-split")]
469impl<'f, S> WebSocketWrite<S> {
470 pub fn set_writev(&mut self, vectored: bool) {
474 self.write_half.vectored = vectored;
475 }
476
477 pub fn set_writev_threshold(&mut self, threshold: usize) {
478 self.write_half.writev_threshold = threshold;
479 }
480
481 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
485 self.write_half.auto_apply_mask = auto_apply_mask;
486 }
487
488 pub fn is_closed(&self) -> bool {
489 self.write_half.closed
490 }
491
492 pub async fn write_frame(
493 &mut self,
494 frame: Frame<'f>,
495 ) -> Result<(), WebSocketError>
496 where
497 S: AsyncWrite + Unpin,
498 {
499 self.write_half.write_frame(&mut self.stream, frame).await
500 }
501
502 pub async fn flush(&mut self) -> Result<(), WebSocketError>
503 where
504 S: AsyncWrite + Unpin,
505 {
506 flush(&mut self.stream).await
507 }
508}
509
510pub struct WebSocket<S> {
512 stream: S,
513 write_half: WriteHalf,
514 read_half: ReadHalf,
515}
516
517impl<'f, S> WebSocket<S> {
518 pub fn after_handshake(stream: S, role: Role) -> Self
538 where
539 S: AsyncRead + AsyncWrite + Unpin,
540 S: AsyncRead + AsyncWrite + Unpin,
541 {
542 Self {
543 stream,
544 write_half: WriteHalf::after_handshake(role),
545 read_half: ReadHalf::after_handshake(role),
546 }
547 }
548
549 #[cfg(feature = "unstable-split")]
553 pub fn split<R, W>(
554 self,
555 split_fn: impl Fn(S) -> (R, W),
556 ) -> (WebSocketRead<R>, WebSocketWrite<W>)
557 where
558 S: AsyncRead + AsyncWrite + Unpin,
559 R: AsyncRead + Unpin,
560 W: AsyncWrite + Unpin,
561 {
562 let (stream, read, write) = self.into_parts_internal();
563 let (r, w) = split_fn(stream);
564 (
565 WebSocketRead {
566 stream: r,
567 read_half: read,
568 write_half: WriteHalf::after_handshake(role),
569 read_half: ReadHalf::after_handshake(role),
570 }
571 )
572 }
573
574 #[cfg(feature = "unstable-split")]
578 pub fn split<R, W>(
579 self,
580 split_fn: impl Fn(S) -> (R, W),
581 ) -> (WebSocketRead<R>, WebSocketWrite<W>)
582 where
583 S: AsyncRead + AsyncWrite + Unpin,
584 R: AsyncRead + Unpin,
585 W: AsyncWrite + Unpin,
586 {
587 let (stream, read, write) = self.into_parts_internal();
588 let (r, w) = split_fn(stream);
589 (
590 WebSocketRead {
591 stream: r,
592 read_half: read,
593 },
594 WebSocketWrite {
595 stream: w,
596 write_half: write,
597 },
598 )
599 }
600
601 #[inline]
603 pub fn into_inner(self) -> S {
604 self.stream
606 }
607
608 #[inline]
610 pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) {
611 (self.stream, self.read_half, self.write_half)
612 }
613
614 pub fn set_writev(&mut self, vectored: bool) {
618 self.write_half.vectored = vectored;
619 }
620
621 pub fn set_writev_threshold(&mut self, threshold: usize) {
622 self.read_half.writev_threshold = threshold;
623 self.write_half.writev_threshold = threshold;
624 }
625
626 pub fn set_auto_close(&mut self, auto_close: bool) {
630 self.read_half.auto_close = auto_close;
631 }
632
633 pub fn set_auto_pong(&mut self, auto_pong: bool) {
637 self.read_half.auto_pong = auto_pong;
638 }
639
640 pub fn set_max_message_size(&mut self, max_message_size: usize) {
644 self.read_half.max_message_size = max_message_size;
645 }
646
647 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
651 self.read_half.auto_apply_mask = auto_apply_mask;
652 self.write_half.auto_apply_mask = auto_apply_mask;
653 }
654
655 pub fn is_closed(&self) -> bool {
656 self.write_half.closed
657 }
658
659
660 pub async fn write_frame(
678 &mut self,
679 frame: Frame<'f>,
680 ) -> Result<(), WebSocketError>
681 where
682 S: AsyncRead + AsyncWrite + Unpin,
683 S: AsyncRead + AsyncWrite + Unpin,
684 {
685 self.write_half.write_frame(&mut self.stream, frame).await?;
686 Ok(())
687 }
688
689 pub async fn flush(&mut self) -> Result<(), WebSocketError>
695 where
696 S: AsyncWrite + Unpin,
697 {
698 flush(&mut self.stream).await
699 }
700
701 pub async fn read_frame(&mut self) -> Result<Frame<'f>, WebSocketError>
733 where
734 S: AsyncRead + AsyncWrite + Unpin,
735 S: AsyncRead + AsyncWrite + Unpin,
736 {
737 loop {
738 let (res, obligated_send) =
739 self.read_half.read_frame_inner(&mut self.stream).await;
740 let is_closed = self.write_half.closed;
741 if let Some(frame) = obligated_send {
742 if !is_closed {
743 self.write_half.write_frame(&mut self.stream, frame).await?;
744 }
745 }
746 if let Some(frame) = res? {
747 if is_closed && frame.opcode != OpCode::Close {
748 return Err(WebSocketError::ConnectionClosed);
749 }
750 break Ok(frame);
751 }
752 }
753 }
754}
755
/// Largest possible frame header in bytes. It is made up of 2 fixed bytes,
/// up to 8 bytes of extended payload length, and a 4-byte masking key
/// (RFC 6455 §5.2).
const MAX_HEADER_SIZE: usize = 2 + 8 + 4;
757
758impl ReadHalf {
759 pub fn after_handshake(role: Role) -> Self {
760 let buffer = BytesMut::with_capacity(8192);
761
762 Self {
763 role,
764 auto_apply_mask: true,
765 auto_close: true,
766 auto_pong: true,
767 writev_threshold: 1024,
768 max_message_size: 64 << 20,
769 buffer,
770 }
771 }
772
773 pub(crate) async fn read_frame_inner<'f, S>(
780 &mut self,
781 stream: &mut S,
782 ) -> (Result<Option<Frame<'f>>, WebSocketError>, Option<Frame<'f>>)
783 where
784 S: AsyncRead + Unpin,
785 {
786 let mut frame = match self.parse_frame_header(stream).await {
787 Ok(frame) => frame,
788 Err(e) => return (Err(e), None),
789 };
790
791 if self.role == Role::Server && self.auto_apply_mask {
792 frame.unmask()
793 };
794
795 match frame.opcode {
796 OpCode::Close if self.auto_close => {
797 match frame.payload.len() {
798 0 => {}
799 1 => return (Err(WebSocketError::InvalidCloseFrame), None),
800 _ => {
801 let code = close::CloseCode::from(u16::from_be_bytes(
802 frame.payload[0..2].try_into().unwrap(),
803 ));
804
805 #[cfg(feature = "simd")]
806 if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() {
807 return (Err(WebSocketError::InvalidUTF8), None);
808 };
809
810 #[cfg(not(feature = "simd"))]
811 if std::str::from_utf8(&frame.payload[2..]).is_err() {
812 return (Err(WebSocketError::InvalidUTF8), None);
813 };
814
815 if !code.is_allowed() {
816 return (
817 Err(WebSocketError::InvalidCloseCode),
818 Some(Frame::close(1002, &frame.payload[2..])),
819 );
820 }
821 }
822 };
823
824 let obligated_send = Frame::close_raw(frame.payload.to_owned().into());
825 (Ok(Some(frame)), Some(obligated_send))
826 }
827 OpCode::Ping if self.auto_pong => {
828 (Ok(None), Some(Frame::pong(frame.payload)))
829 }
830 OpCode::Text => {
831 if frame.fin && !frame.is_utf8() {
832 (Err(WebSocketError::InvalidUTF8), None)
833 } else {
834 (Ok(Some(frame)), None)
835 }
836 }
837 _ => (Ok(Some(frame)), None),
838 }
839 }
840
841 async fn parse_frame_header<'a, S>(
842 &mut self,
843 stream: &mut S,
844 ) -> Result<Frame<'a>, WebSocketError>
845 where
846 S: AsyncRead + Unpin,
847 {
848 macro_rules! eof {
849 ($n:expr) => {{
850 if $n == 0 {
851 return Err(WebSocketError::UnexpectedEOF);
852 }
853 }};
854 }
855
856 while self.buffer.remaining() < 2 {
858 eof!(stream.read_buf(&mut self.buffer).await?);
859 }
860
861 let fin = self.buffer[0] & 0b10000000 != 0;
862 let rsv1 = self.buffer[0] & 0b01000000 != 0;
863 let rsv2 = self.buffer[0] & 0b00100000 != 0;
864 let rsv3 = self.buffer[0] & 0b00010000 != 0;
865
866 if rsv1 || rsv2 || rsv3 {
867 return Err(WebSocketError::ReservedBitsNotZero);
868 }
869
870 let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?;
871 let masked = self.buffer[1] & 0b10000000 != 0;
872
873 let length_code = self.buffer[1] & 0x7F;
874 let extra = match length_code {
875 126 => 2,
876 127 => 8,
877 _ => 0,
878 };
879
880 self.buffer.advance(2);
881 while self.buffer.remaining() < extra + masked as usize * 4 {
882 eof!(stream.read_buf(&mut self.buffer).await?);
883 }
884
885 #[allow(unexpected_cfgs)]
886 let payload_len: usize = match extra {
887 0 => usize::from(length_code),
888 2 => self.buffer.get_u16() as usize,
889 #[cfg(any(target_pointer_width = "64", target_pointer_width = "128"))]
890 8 => self.buffer.get_u64() as usize,
891 #[cfg(any(
893 target_pointer_width = "8",
894 target_pointer_width = "16",
895 target_pointer_width = "32"
896 ))]
897 8 => match usize::try_from(self.buffer.get_u64()) {
898 Ok(length) => length,
899 Err(_) => return Err(WebSocketError::FrameTooLarge),
900 },
901 _ => unreachable!(),
902 };
903
904 let mask = if masked {
905 Some(self.buffer.get_u32().to_be_bytes())
906 } else {
907 None
908 };
909
910 if frame::is_control(opcode) && !fin {
911 return Err(WebSocketError::ControlFrameFragmented);
912 }
913
914 if opcode == OpCode::Ping && payload_len > 125 {
915 return Err(WebSocketError::PingFrameTooLarge);
916 }
917
918 if payload_len >= self.max_message_size {
919 return Err(WebSocketError::FrameTooLarge);
920 }
921
922 self.buffer.reserve(payload_len + MAX_HEADER_SIZE);
926 while payload_len > self.buffer.remaining() {
927 eof!(stream.read_buf(&mut self.buffer).await?);
928 }
929
930 let payload = self.buffer.split_to(payload_len);
932 let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload));
933 Ok(frame)
934 }
935}
936
937impl WriteHalf {
938 pub fn after_handshake(role: Role) -> Self {
939 Self {
940 role,
941 closed: false,
942 auto_apply_mask: true,
943 vectored: true,
944 writev_threshold: 1024,
945 write_buffer: Vec::with_capacity(2),
946 }
947 }
948
949 pub async fn write_frame<'a, S>(
951 &'a mut self,
952 stream: &mut S,
953 mut frame: Frame<'a>,
954 ) -> Result<(), WebSocketError>
955 where
956 S: AsyncWrite + Unpin,
957 {
958 if self.role == Role::Client && self.auto_apply_mask {
959 frame.mask();
960 }
961
962 if frame.opcode == OpCode::Close {
963 self.closed = true;
964 } else if self.closed {
965 return Err(WebSocketError::ConnectionClosed);
966 }
967
968 if self.vectored && frame.payload.len() > self.writev_threshold {
969 frame.writev(stream).await?;
970 } else {
971 let text = frame.write(&mut self.write_buffer);
972 stream.write_all(text).await?;
973 }
974
975 Ok(())
976 }
977}
978
#[cfg(test)]
mod tests {
  use super::*;

  // Compile-time check intended to assert that `WebSocket<TcpStream>` does
  // NOT implement `Sync`, using the "ambiguous impl" trick.
  //
  // NOTE(review): the trait lookup happens inside the generic body of
  // `assert_unsync<S>`, where `S: Sync` is not among the known bounds. Only
  // the `()` impl can therefore apply, and inference presumably succeeds for
  // every `S`. If so, this check passes whether or not the type is `Sync`.
  // Confirm, e.g. by instantiating with a known-`Sync` type such as `u8`,
  // before relying on it.
  const _: () = {
    const fn assert_unsync<S>() {
      // Generic trait with a blanket impl over `()` for every type.
      trait AmbiguousIfImpl<A> {
        // Associated item that gives the trait something to name below.
        fn some_item() {}
      }

      impl<T: ?Sized> AmbiguousIfImpl<()> for T {}

      // Marker type for a second impl that applies only to `Sync` types.
      #[allow(dead_code)]
      struct Invalid;

      impl<T: ?Sized + Sync> AmbiguousIfImpl<Invalid> for T {}

      // With exactly one applicable impl the `_` is inferred and this
      // compiles. With both applicable, inference is ambiguous and
      // compilation fails.
      let _ = <S as AmbiguousIfImpl<_>>::some_item;
    }
    assert_unsync::<WebSocket<tokio::net::TcpStream>>();
  };
}