// async-ws 0.4.0
//
// async websocket implementation
// Documentation
use crate::connection::encode::EncodeState::Sending;
use crate::connection::WsConnectionError;
use crate::frame::{
    payload_mask, FrameHead, WsControlFrame, WsControlFrameKind, WsDataFrameKind, WsFrameKind,
};
use crate::message::WsMessageKind;
use futures::task::{Context, Poll};
use futures::{io, AsyncRead, AsyncWrite};
use rand::{thread_rng, RngCore};
use std::mem::replace;
use std::pin::Pin;

/// Encoder state machine that `EncodeState::poll` drives to write frames to the transport.
#[derive(Debug)]
pub(crate) enum EncodeState {
    /// Frames can still be buffered, written and flushed.
    Sending {
        /// The frame currently being filled or written out, if any.
        frame_in_progress: Option<FrameInProgress>,
        /// Kind to use for the next data frame of the open message.
        /// `Some` between `start_message` and `end_message`; `append_data`
        /// switches it to `Continuation` after taking the current value.
        next_data_frame_kind: Option<WsDataFrameKind>,
        /// A control frame waiting to be sent once no frame is in progress.
        queued_control: Option<WsControlFrame>,
        /// Flush progress: `None` means no flush is requested, `Some(false)` means a
        /// flush is requested but not complete, `Some(true)` means the transport
        /// reported a successful flush.
        flushed: Option<bool>,
        /// Set once a close control frame has been picked up for sending; after the
        /// following flush completes the state becomes `Done`.
        closing: bool,
    },
    /// Writing or flushing the transport failed; the error can be retrieved with `take_err`.
    Err(io::Error),
    /// Nothing more will be sent (close frame flushed, or the error was taken).
    Done,
}

/// Outcome reported by `EncodeState::poll` when it returns `Poll::Ready`.
#[derive(Debug)]
pub(crate) enum EncodeReady {
    /// No flush is requested and no control frame is queued; a message is open
    /// and its data is only being buffered.
    Buffering,
    /// The transport was flushed and no message is currently open.
    FlushedMessages,
    /// The transport was flushed while a message is still open.
    FlushedFrames,
    /// The encoder is in the `Err` state.
    Error,
    /// The encoder is in the `Done` state.
    Done,
}

impl EncodeState {
    /// Creates an encoder in the `Sending` state with nothing buffered or queued
    /// and an initial flush already requested.
    pub fn new() -> EncodeState {
        Self::Sending {
            frame_in_progress: None,
            next_data_frame_kind: None,
            queued_control: None,
            flushed: Some(false),
            closing: false,
        }
    }

    /// Opens a new message of the given kind and clears any flush request.
    ///
    /// # Panics
    /// Panics if the encoder is not in the `Sending` state, if a frame is still
    /// in progress, or if a message is already open.
    pub fn start_message(&mut self, kind: WsMessageKind) {
        match self {
            Sending {
                next_data_frame_kind,
                frame_in_progress: None,
                flushed,
                ..
            } => {
                assert!(next_data_frame_kind.is_none());
                *flushed = None;
                *next_data_frame_kind = Some(kind.frame_kind());
            }
            _ => unreachable!(),
        }
    }

    /// Closes the open message: the frame in progress (or a fresh empty frame if
    /// there is none) is marked final and a flush is requested.
    ///
    /// Does nothing when the encoder is not `Sending` or no message is open.
    pub fn end_message(&mut self, mask: bool) {
        let (pending_kind, current_frame, flush_state) = match self {
            Sending {
                next_data_frame_kind,
                frame_in_progress,
                flushed,
                ..
            } => (next_data_frame_kind, frame_in_progress, flushed),
            _ => return,
        };
        let kind = match pending_kind.take() {
            Some(kind) => kind,
            None => return,
        };
        match current_frame {
            Some(frame) => frame.start_writing(true),
            None => {
                // No data buffered: terminate the message with an empty final frame.
                let mut frame = FrameInProgress::new(kind.frame_kind(), mask);
                frame.start_writing(true);
                *current_frame = Some(frame);
            }
        }
        // Request a flush unless one is already requested or recorded.
        flush_state.get_or_insert(false);
    }

    /// Queues a control frame, replacing any queued frame that is not a close frame.
    ///
    /// A queued close frame is never replaced. Ignored when not `Sending`.
    pub fn queue_control(&mut self, control: WsControlFrame) {
        if let Sending { queued_control, .. } = self {
            let close_already_queued = match queued_control {
                Some(queued) => queued.kind() == WsControlFrameKind::Close,
                None => false,
            };
            if !close_already_queued {
                *queued_control = Some(control);
            }
        }
    }

    /// Appends message data to the frame in progress (creating one if needed),
    /// clears any flush request and returns the number of bytes accepted.
    ///
    /// # Panics
    /// Panics if the encoder is not `Sending` or no message is open.
    pub fn append_data(&mut self, buf: &[u8], mask: bool) -> usize {
        match self {
            Sending {
                next_data_frame_kind,
                frame_in_progress,
                flushed,
                ..
            } => {
                // Every frame after this one continues the same message.
                let kind = next_data_frame_kind
                    .replace(WsDataFrameKind::Continuation)
                    .unwrap();
                *flushed = None;
                let frame = frame_in_progress
                    .get_or_insert_with(|| FrameInProgress::new(kind.frame_kind(), mask));
                frame.append_data(buf)
            }
            _ => unreachable!(),
        }
    }

    /// Requests a flush unless one is already requested or recorded as complete.
    ///
    /// # Panics
    /// Panics if the encoder is not `Sending`.
    pub fn start_flushing(&mut self) {
        match self {
            Sending { flushed, .. } => {
                if flushed.is_none() {
                    *flushed = Some(false);
                }
            }
            _ => unreachable!(),
        }
    }

    /// Takes the stored error, if any, leaving the encoder in the `Done` state.
    pub fn take_err(&mut self) -> Option<WsConnectionError> {
        match self {
            EncodeState::Err(_) => match replace(self, EncodeState::Done) {
                EncodeState::Err(err) => Some(err.into()),
                _ => unreachable!(),
            },
            _ => None,
        }
    }

    /// Drives the encoder as far as the transport allows.
    ///
    /// In order of priority: writes the frame in progress, turns a queued control
    /// frame into the next frame, then flushes the transport if a flush is
    /// requested. Returns `Poll::Pending` whenever the transport does.
    pub fn poll<T: AsyncRead + AsyncWrite + Unpin>(
        &mut self,
        transport: &mut T,
        cx: &mut Context<'_>,
        mask: bool,
    ) -> Poll<EncodeReady> {
        loop {
            match self {
                Self::Err(_) => return Poll::Ready(EncodeReady::Error),
                Self::Done => return Poll::Ready(EncodeReady::Done),
                Self::Sending {
                    frame_in_progress,
                    next_data_frame_kind,
                    queued_control,
                    flushed,
                    closing,
                } => match frame_in_progress {
                    Some(frame) => match frame.poll(transport, cx) {
                        Poll::Pending => return Poll::Pending,
                        Poll::Ready(FrameInProgressReady::Written) => *frame_in_progress = None,
                        Poll::Ready(FrameInProgressReady::Err(err)) => *self = Self::Err(err),
                        Poll::Ready(FrameInProgressReady::Buffering) => {
                            let flush_requested = flushed.is_some();
                            let control_waiting = queued_control.is_some();
                            if !(flush_requested || control_waiting) {
                                return Poll::Ready(EncodeReady::Buffering);
                            }
                            // Push the partial frame out as a non-final fragment.
                            frame.start_writing(false);
                        }
                    },
                    None => match queued_control.take() {
                        Some(control) => {
                            if control.kind() == WsControlFrameKind::Close {
                                *closing = true;
                            }
                            *frame_in_progress = Some(FrameInProgress::new_control(control, mask));
                            // Control frames are always followed by a flush.
                            flushed.get_or_insert(false);
                        }
                        None => match (*flushed, *closing) {
                            (Some(false), _) => match Pin::new(&mut *transport).poll_flush(cx) {
                                Poll::Pending => return Poll::Pending,
                                Poll::Ready(Ok(())) => *flushed = Some(true),
                                Poll::Ready(Err(err)) => *self = Self::Err(err),
                            },
                            (Some(true), true) => *self = Self::Done,
                            (Some(true), false) => {
                                return Poll::Ready(if next_data_frame_kind.is_some() {
                                    EncodeReady::FlushedFrames
                                } else {
                                    EncodeReady::FlushedMessages
                                });
                            }
                            (None, true) => unreachable!("should always flush after close frame"),
                            (None, false) => {
                                if next_data_frame_kind.is_none() {
                                    unreachable!("should always flush after message fin");
                                }
                                return Poll::Ready(EncodeReady::Buffering);
                            }
                        },
                    },
                },
            }
        }
    }
}

/// A single frame that is being filled with payload and then written to the transport.
#[derive(Debug)]
pub struct FrameInProgress {
    /// Frame kind; its opcode is put into the frame head.
    kind: WsFrameKind,
    /// Backing storage: the first `FRAME_BUFFER_PAYLOAD_OFFSET` bytes are reserved
    /// for the frame head, the (masked) payload follows.
    buffer: [u8; 1300],
    /// Masking key; all zeros when the frame was created with masking disabled.
    mask: [u8; 4],
    /// `None` while payload may still be appended. Once the head has been encoded
    /// this is `Some(index)` of the next byte in `buffer` to hand to the transport.
    written: Option<usize>,
    /// Index one past the last payload byte in `buffer`.
    filled: usize,
}

/// Outcome reported by `FrameInProgress::poll` when it returns `Poll::Ready`.
#[derive(Debug)]
enum FrameInProgressReady {
    /// Writing has not been started; the frame still accepts payload.
    Buffering,
    /// Head and payload have been handed to the transport completely.
    Written,
    /// The transport reported a write error.
    Err(io::Error),
}

/// Index in `FrameInProgress::buffer` at which the payload starts. The bytes before
/// it are reserved for the frame head, which is encoded directly in front of the payload.
const FRAME_BUFFER_PAYLOAD_OFFSET: usize = 8;

impl FrameInProgress {
    fn new(kind: WsFrameKind, mask: bool) -> Self {
        FrameInProgress {
            kind,
            buffer: [0u8; 1300],
            mask: match mask {
                true => loop {
                    let r = thread_rng().next_u32();
                    if r != 0 {
                        break r.to_ne_bytes();
                    }
                },
                false => [0u8, 0u8, 0u8, 0u8],
            },
            written: None,
            filled: FRAME_BUFFER_PAYLOAD_OFFSET,
        }
    }
    fn new_control(control: WsControlFrame, mask: bool) -> Self {
        let mut frame = Self::new(control.kind().frame_kind(), mask);
        frame.append_data(control.payload());
        frame.start_writing(true);
        frame
    }
    fn append_data(&mut self, buf: &[u8]) -> usize {
        assert!(self.written.is_none());
        let payload_len = self.buffer.len() - self.filled;
        let append = buf.len().min(payload_len);
        let target_slice = &mut self.buffer[self.filled..self.filled + append];
        target_slice.copy_from_slice(&buf[..append]);
        payload_mask(self.mask, payload_len, target_slice);
        self.filled += append;
        if append < buf.len() {
            self.start_writing(false)
        }
        append
    }
    fn start_writing(&mut self, fin: bool) {
        assert!(self.written.is_none());
        let head = FrameHead {
            fin,
            opcode: self.kind.opcode(),
            mask: self.mask,
            payload_len: (self.filled - FRAME_BUFFER_PAYLOAD_OFFSET) as u64,
        };
        let offset = FRAME_BUFFER_PAYLOAD_OFFSET - head.len_bytes();
        head.encode(&mut self.buffer[offset..FRAME_BUFFER_PAYLOAD_OFFSET]);
        self.written = Some(offset);
    }
    fn poll<T: AsyncRead + AsyncWrite + Unpin>(
        &mut self,
        transport: &mut T,
        cx: &mut Context<'_>,
    ) -> Poll<FrameInProgressReady> {
        let mut offset = match self.written {
            None => return Poll::Ready(FrameInProgressReady::Buffering),
            Some(offset) => offset,
        };
        loop {
            match Pin::new(&mut *transport).poll_write(cx, &self.buffer[offset..self.filled]) {
                Poll::Ready(Ok(n)) => {
                    offset += n;
                    self.written = Some(offset);
                    if offset == self.filled {
                        return Poll::Ready(FrameInProgressReady::Written);
                    }
                }
                Poll::Ready(Err(err)) => return Poll::Ready(FrameInProgressReady::Err(err)),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}