1#![cfg_attr(docsrs, feature(doc_cfg))]
151
152mod close;
153mod error;
154mod fragment;
155mod frame;
156#[cfg(feature = "upgrade")]
158#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
159pub mod handshake;
160mod mask;
161#[cfg(feature = "upgrade")]
163#[cfg_attr(docsrs, doc(cfg(feature = "upgrade")))]
164pub mod upgrade;
165
166use bytes::Buf;
167
168use bytes::BytesMut;
169#[cfg(feature = "unstable-split")]
170use std::future::Future;
171
172use tokio::io::AsyncRead;
173use tokio::io::AsyncReadExt;
174use tokio::io::AsyncWrite;
175use tokio::io::AsyncWriteExt;
176
177pub use crate::close::CloseCode;
178pub use crate::error::WebSocketError;
179pub use crate::fragment::FragmentCollector;
180#[cfg(feature = "unstable-split")]
181pub use crate::fragment::FragmentCollectorRead;
182pub use crate::frame::Frame;
183pub use crate::frame::OpCode;
184pub use crate::frame::Payload;
185pub use crate::mask::unmask;
186
187#[derive(Copy, Clone, PartialEq)]
188pub enum Role {
189 Server,
190 Client,
191}
192
193pub(crate) struct WriteHalf {
194 role: Role,
195 closed: bool,
196 vectored: bool,
197 auto_apply_mask: bool,
198 writev_threshold: usize,
199 write_buffer: Vec<u8>,
200}
201
202pub(crate) struct ReadHalf {
203 role: Role,
204 auto_apply_mask: bool,
205 auto_close: bool,
206 auto_pong: bool,
207 writev_threshold: usize,
208 max_message_size: usize,
209 buffer: BytesMut,
210}
211
212#[cfg(feature = "unstable-split")]
213pub struct WebSocketRead<S> {
214 stream: S,
215 read_half: ReadHalf,
216}
217
218#[cfg(feature = "unstable-split")]
219pub struct WebSocketWrite<S> {
220 stream: S,
221 write_half: WriteHalf,
222}
223
224#[cfg(feature = "unstable-split")]
225pub fn after_handshake_split<R, W>(
227 read: R,
228 write: W,
229 role: Role,
230) -> (WebSocketRead<R>, WebSocketWrite<W>)
231where
232 R: AsyncRead + Unpin,
233 W: AsyncWrite + Unpin,
234{
235 (
236 WebSocketRead {
237 stream: read,
238 read_half: ReadHalf::after_handshake(role),
239 },
240 WebSocketWrite {
241 stream: write,
242 write_half: WriteHalf::after_handshake(role),
243 },
244 )
245}
246
247#[cfg(feature = "unstable-split")]
248impl<'f, S> WebSocketRead<S> {
249 #[inline]
251 pub(crate) fn into_parts_internal(self) -> (S, ReadHalf) {
252 (self.stream, self.read_half)
253 }
254
255 pub fn set_writev_threshold(&mut self, threshold: usize) {
256 self.read_half.writev_threshold = threshold;
257 }
258
259 pub fn set_auto_close(&mut self, auto_close: bool) {
263 self.read_half.auto_close = auto_close;
264 }
265
266 pub fn set_auto_pong(&mut self, auto_pong: bool) {
270 self.read_half.auto_pong = auto_pong;
271 }
272
273 pub fn set_max_message_size(&mut self, max_message_size: usize) {
277 self.read_half.max_message_size = max_message_size;
278 }
279
280 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
284 self.read_half.auto_apply_mask = auto_apply_mask;
285 }
286
287 pub async fn read_frame<R, E>(
289 &mut self,
290 send_fn: &mut impl FnMut(Frame<'f>) -> R,
291 ) -> Result<Frame<'_>, WebSocketError>
292 where
293 S: AsyncRead + Unpin,
294 E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
295 R: Future<Output = Result<(), E>>,
296 {
297 loop {
298 let (res, obligated_send) =
299 self.read_half.read_frame_inner(&mut self.stream).await;
300 if let Some(frame) = obligated_send {
301 let res = send_fn(frame).await;
302 res.map_err(|e| WebSocketError::SendError(e.into()))?;
303 }
304 if let Some(frame) = res? {
305 break Ok(frame);
306 }
307 }
308 }
309}
310
311#[cfg(feature = "unstable-split")]
312impl<'f, S> WebSocketWrite<S> {
313 pub fn set_writev(&mut self, vectored: bool) {
317 self.write_half.vectored = vectored;
318 }
319
320 pub fn set_writev_threshold(&mut self, threshold: usize) {
321 self.write_half.writev_threshold = threshold;
322 }
323
324 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
328 self.write_half.auto_apply_mask = auto_apply_mask;
329 }
330
331 pub fn is_closed(&self) -> bool {
332 self.write_half.closed
333 }
334
335 pub async fn write_frame(
336 &mut self,
337 frame: Frame<'f>,
338 ) -> Result<(), WebSocketError>
339 where
340 S: AsyncWrite + Unpin,
341 {
342 self.write_half.write_frame(&mut self.stream, frame).await
343 }
344
345 pub async fn flush(&mut self) -> Result<(), WebSocketError>
346 where
347 S: AsyncWrite + Unpin,
348 {
349 flush(&mut self.stream).await
350 }
351}
352
353#[inline]
354async fn flush<S>(stream: &mut S) -> Result<(), WebSocketError>
355where
356 S: AsyncWrite + Unpin,
357{
358 stream.flush().await.map_err(WebSocketError::IoError)
359}
360
361pub struct WebSocket<S> {
363 stream: S,
364 write_half: WriteHalf,
365 read_half: ReadHalf,
366}
367
368impl<'f, S> WebSocket<S> {
369 pub fn after_handshake(stream: S, role: Role) -> Self
389 where
390 S: AsyncRead + AsyncWrite + Unpin,
391 {
392 Self {
393 stream,
394 write_half: WriteHalf::after_handshake(role),
395 read_half: ReadHalf::after_handshake(role),
396 }
397 }
398
399 #[cfg(feature = "unstable-split")]
403 pub fn split<R, W>(
404 self,
405 split_fn: impl Fn(S) -> (R, W),
406 ) -> (WebSocketRead<R>, WebSocketWrite<W>)
407 where
408 S: AsyncRead + AsyncWrite + Unpin,
409 R: AsyncRead + Unpin,
410 W: AsyncWrite + Unpin,
411 {
412 let (stream, read, write) = self.into_parts_internal();
413 let (r, w) = split_fn(stream);
414 (
415 WebSocketRead {
416 stream: r,
417 read_half: read,
418 },
419 WebSocketWrite {
420 stream: w,
421 write_half: write,
422 },
423 )
424 }
425
426 #[inline]
428 pub fn into_inner(self) -> S {
429 self.stream
431 }
432
433 #[inline]
435 pub(crate) fn into_parts_internal(self) -> (S, ReadHalf, WriteHalf) {
436 (self.stream, self.read_half, self.write_half)
437 }
438
439 pub fn set_writev(&mut self, vectored: bool) {
443 self.write_half.vectored = vectored;
444 }
445
446 pub fn set_writev_threshold(&mut self, threshold: usize) {
447 self.read_half.writev_threshold = threshold;
448 self.write_half.writev_threshold = threshold;
449 }
450
451 pub fn set_auto_close(&mut self, auto_close: bool) {
455 self.read_half.auto_close = auto_close;
456 }
457
458 pub fn set_auto_pong(&mut self, auto_pong: bool) {
462 self.read_half.auto_pong = auto_pong;
463 }
464
465 pub fn set_max_message_size(&mut self, max_message_size: usize) {
469 self.read_half.max_message_size = max_message_size;
470 }
471
472 pub fn set_auto_apply_mask(&mut self, auto_apply_mask: bool) {
476 self.read_half.auto_apply_mask = auto_apply_mask;
477 self.write_half.auto_apply_mask = auto_apply_mask;
478 }
479
480 pub fn is_closed(&self) -> bool {
481 self.write_half.closed
482 }
483
484 pub async fn write_frame(
502 &mut self,
503 frame: Frame<'f>,
504 ) -> Result<(), WebSocketError>
505 where
506 S: AsyncRead + AsyncWrite + Unpin,
507 {
508 self.write_half.write_frame(&mut self.stream, frame).await?;
509 Ok(())
510 }
511
512 pub async fn flush(&mut self) -> Result<(), WebSocketError>
518 where
519 S: AsyncWrite + Unpin,
520 {
521 flush(&mut self.stream).await
522 }
523
524 pub async fn read_frame(&mut self) -> Result<Frame<'f>, WebSocketError>
551 where
552 S: AsyncRead + AsyncWrite + Unpin,
553 {
554 loop {
555 let (res, obligated_send) =
556 self.read_half.read_frame_inner(&mut self.stream).await;
557 let is_closed = self.write_half.closed;
558 if let Some(frame) = obligated_send
559 && !is_closed
560 {
561 self.write_half.write_frame(&mut self.stream, frame).await?;
562 }
563 if let Some(frame) = res? {
564 if is_closed && frame.opcode != OpCode::Close {
565 return Err(WebSocketError::ConnectionClosed);
566 }
567 break Ok(frame);
568 }
569 }
570 }
571}
572
573const MAX_HEADER_SIZE: usize = 14;
574
575impl ReadHalf {
576 pub fn after_handshake(role: Role) -> Self {
577 let buffer = BytesMut::with_capacity(8192);
578
579 Self {
580 role,
581 auto_apply_mask: true,
582 auto_close: true,
583 auto_pong: true,
584 writev_threshold: 1024,
585 max_message_size: 64 << 20,
586 buffer,
587 }
588 }
589
590 pub(crate) async fn read_frame_inner<'f, S>(
597 &mut self,
598 stream: &mut S,
599 ) -> (Result<Option<Frame<'f>>, WebSocketError>, Option<Frame<'f>>)
600 where
601 S: AsyncRead + Unpin,
602 {
603 let mut frame = match self.parse_frame_header(stream).await {
604 Ok(frame) => frame,
605 Err(e) => return (Err(e), None),
606 };
607
608 if self.role == Role::Server && self.auto_apply_mask {
609 frame.unmask()
610 };
611
612 match frame.opcode {
613 OpCode::Close if self.auto_close => {
614 match frame.payload.len() {
615 0 => {}
616 1 => return (Err(WebSocketError::InvalidCloseFrame), None),
617 _ => {
618 let code = close::CloseCode::from(u16::from_be_bytes(
619 frame.payload[0..2].try_into().unwrap(),
620 ));
621
622 #[cfg(feature = "simd")]
623 if simdutf8::basic::from_utf8(&frame.payload[2..]).is_err() {
624 return (Err(WebSocketError::InvalidUTF8), None);
625 };
626
627 #[cfg(not(feature = "simd"))]
628 if std::str::from_utf8(&frame.payload[2..]).is_err() {
629 return (Err(WebSocketError::InvalidUTF8), None);
630 };
631
632 if !code.is_allowed() {
633 return (
634 Err(WebSocketError::InvalidCloseCode),
635 Some(Frame::close(1002, &frame.payload[2..])),
636 );
637 }
638 }
639 };
640
641 let obligated_send = Frame::close_raw(frame.payload.to_owned().into());
642 (Ok(Some(frame)), Some(obligated_send))
643 }
644 OpCode::Ping if self.auto_pong => {
645 (Ok(None), Some(Frame::pong(frame.payload)))
646 }
647 OpCode::Text => {
648 if frame.fin && !frame.is_utf8() {
649 (Err(WebSocketError::InvalidUTF8), None)
650 } else {
651 (Ok(Some(frame)), None)
652 }
653 }
654 _ => (Ok(Some(frame)), None),
655 }
656 }
657
658 async fn parse_frame_header<'a, S>(
659 &mut self,
660 stream: &mut S,
661 ) -> Result<Frame<'a>, WebSocketError>
662 where
663 S: AsyncRead + Unpin,
664 {
665 macro_rules! eof {
666 ($n:expr) => {{
667 if $n == 0 {
668 return Err(WebSocketError::UnexpectedEOF);
669 }
670 }};
671 }
672
673 while self.buffer.remaining() < 2 {
675 eof!(stream.read_buf(&mut self.buffer).await?);
676 }
677
678 let fin = self.buffer[0] & 0b10000000 != 0;
679 let rsv1 = self.buffer[0] & 0b01000000 != 0;
680 let rsv2 = self.buffer[0] & 0b00100000 != 0;
681 let rsv3 = self.buffer[0] & 0b00010000 != 0;
682
683 if rsv1 || rsv2 || rsv3 {
684 return Err(WebSocketError::ReservedBitsNotZero);
685 }
686
687 let opcode = frame::OpCode::try_from(self.buffer[0] & 0b00001111)?;
688 let masked = self.buffer[1] & 0b10000000 != 0;
689
690 let length_code = self.buffer[1] & 0x7F;
691 let extra = match length_code {
692 126 => 2,
693 127 => 8,
694 _ => 0,
695 };
696
697 self.buffer.advance(2);
698 while self.buffer.remaining() < extra + masked as usize * 4 {
699 eof!(stream.read_buf(&mut self.buffer).await?);
700 }
701
702 let payload_len: usize = match extra {
703 0 => usize::from(length_code),
704 2 => self.buffer.get_u16() as usize,
705 #[cfg(target_pointer_width = "64")]
706 8 => self.buffer.get_u64() as usize,
707 #[cfg(any(target_pointer_width = "16", target_pointer_width = "32"))]
709 8 => match usize::try_from(self.buffer.get_u64()) {
710 Ok(length) => length,
711 Err(_) => return Err(WebSocketError::FrameTooLarge),
712 },
713 _ => unreachable!(),
714 };
715
716 let mask = if masked {
717 Some(self.buffer.get_u32().to_be_bytes())
718 } else {
719 None
720 };
721
722 if frame::is_control(opcode) && !fin {
723 return Err(WebSocketError::ControlFrameFragmented);
724 }
725
726 if opcode == OpCode::Ping && payload_len > 125 {
727 return Err(WebSocketError::PingFrameTooLarge);
728 }
729
730 if payload_len >= self.max_message_size {
731 return Err(WebSocketError::FrameTooLarge);
732 }
733
734 self.buffer.reserve(payload_len + MAX_HEADER_SIZE);
736 while payload_len > self.buffer.remaining() {
737 eof!(stream.read_buf(&mut self.buffer).await?);
738 }
739
740 let payload = self.buffer.split_to(payload_len);
742 let frame = Frame::new(fin, opcode, mask, Payload::Bytes(payload));
743 Ok(frame)
744 }
745}
746
747impl WriteHalf {
748 pub fn after_handshake(role: Role) -> Self {
749 Self {
750 role,
751 closed: false,
752 auto_apply_mask: true,
753 vectored: true,
754 writev_threshold: 1024,
755 write_buffer: Vec::with_capacity(2),
756 }
757 }
758
759 pub async fn write_frame<'a, S>(
761 &'a mut self,
762 stream: &mut S,
763 mut frame: Frame<'a>,
764 ) -> Result<(), WebSocketError>
765 where
766 S: AsyncWrite + Unpin,
767 {
768 if self.role == Role::Client && self.auto_apply_mask {
769 frame.mask();
770 }
771
772 if frame.opcode == OpCode::Close {
773 self.closed = true;
774 } else if self.closed {
775 return Err(WebSocketError::ConnectionClosed);
776 }
777
778 if self.vectored && frame.payload.len() > self.writev_threshold {
779 frame.writev(stream).await?;
780 } else {
781 let text = frame.write(&mut self.write_buffer);
782 stream.write_all(text).await?;
783 }
784
785 Ok(())
786 }
787}
788
789#[cfg(test)]
790mod tests {
791 use super::*;
792
793 const _: () = {
794 const fn assert_unsync<S>() {
795 trait AmbiguousIfImpl<A> {
797 fn some_item() {}
799 }
800
801 impl<T: ?Sized> AmbiguousIfImpl<()> for T {}
802
803 #[allow(dead_code)]
806 struct Invalid;
807
808 impl<T: ?Sized + Sync> AmbiguousIfImpl<Invalid> for T {}
809
810 let _ = <S as AmbiguousIfImpl<_>>::some_item;
814 }
815 assert_unsync::<WebSocket<tokio::net::TcpStream>>();
816 };
817}