use bytes::{Bytes, BytesMut};
use futures::{
channel::mpsc::{Receiver, Sender, UnboundedSender},
stream::FusedStream,
task::Waker,
Stream,
};
use std::{
io,
pin::Pin,
task::{Context, Poll},
};
use log::debug;
use tokio::prelude::{AsyncRead, AsyncWrite};
use crate::{
error::Error,
frame::{Flag, Flags, Frame, Type},
StreamId,
};
/// User-facing handle of one multiplexed stream; implements `AsyncRead` and
/// `AsyncWrite` on top of frames exchanged with the session over channels.
#[derive(Debug)]
pub struct StreamHandle {
/// Stream identifier written into every frame this handle builds.
id: StreamId,
/// Current lifecycle state (see `StreamState`).
state: StreamState,
/// Receive window the stream was created with; the ceiling used when
/// computing how much window to hand back in `send_window_update`.
max_recv_window: u32,
/// Bytes the peer may still send; a data frame longer than this is rejected
/// with `Error::RecvWindowExceeded`.
recv_window: u32,
/// Bytes this side may still write before the peer grants more window.
send_window: u32,
/// Data received from the peer but not yet copied out by `poll_read`.
read_buf: BytesMut,
/// Bounded event channel; data frames go through it so that a full channel
/// surfaces as `Error::WouldBlock` / `Poll::Pending` to writers.
event_sender: Sender<StreamEvent>,
/// Unbounded event channel; used for control frames (window updates, FIN,
/// RST), state changes and go-away, which are sent without waiting.
unbound_event_sender: UnboundedSender<StreamEvent>,
/// Frames delivered to this stream (presumably routed by the session).
frame_receiver: Receiver<Frame>,
/// Waker of a writer parked because `send_window` was zero; woken by
/// `handle_window_update`.
writeable_wake: Option<Waker>,
}
impl StreamHandle {
    /// Creates a handle for stream `id`.
    ///
    /// `state` must be `StreamState::Init` (the next outgoing frame carries
    /// SYN) or `StreamState::SynReceived` (the next outgoing frame carries
    /// ACK). Any other state panics. The receive window starts full at
    /// `recv_window_size`, which is also remembered as its maximum.
    pub(crate) fn new(
        id: StreamId,
        event_sender: Sender<StreamEvent>,
        unbound_event_sender: UnboundedSender<StreamEvent>,
        frame_receiver: Receiver<Frame>,
        state: StreamState,
        recv_window_size: u32,
        send_window_size: u32,
    ) -> StreamHandle {
        assert!(state == StreamState::Init || state == StreamState::SynReceived);
        StreamHandle {
            id,
            state,
            max_recv_window: recv_window_size,
            recv_window: recv_window_size,
            send_window: send_window_size,
            read_buf: BytesMut::default(),
            event_sender,
            unbound_event_sender,
            frame_receiver,
            writeable_wake: None,
        }
    }

    /// Returns this stream's identifier.
    pub fn id(&self) -> StreamId {
        self.id
    }

    /// Returns the current lifecycle state.
    pub fn state(&self) -> StreamState {
        self.state
    }

    /// Returns how many more bytes the peer may send before a window update.
    pub fn recv_window(&self) -> u32 {
        self.recv_window
    }

    /// Returns how many more bytes may be written before the peer grants
    /// more window.
    pub fn send_window(&self) -> u32 {
        self.send_window
    }

    /// Closes this side of the stream and advances the state machine:
    ///
    /// * `Init` / `SynSent` / `SynReceived` / `Established` -> `LocalClosing`,
    ///   FIN sent.
    /// * `RemoteClosing` -> `Closed`, FIN sent, `StateChanged` emitted.
    /// * `LocalClosing` / `Reset` / `Closed` -> `Closed`, `StateChanged`
    ///   emitted.
    ///
    /// The state is updated before `send_close` computes its flags, so a
    /// pending SYN/ACK is not attached to the FIN.
    ///
    /// # Errors
    /// `Error::SessionShutdown` if the unbounded event channel is closed.
    fn close(&mut self) -> Result<(), Error> {
        match self.state {
            StreamState::SynSent
            | StreamState::SynReceived
            | StreamState::Established
            | StreamState::Init => {
                self.state = StreamState::LocalClosing;
                self.send_close()?;
            }
            StreamState::RemoteClosing => {
                self.state = StreamState::Closed;
                self.send_close()?;
                let event = StreamEvent::StateChanged((self.id, self.state));
                self.unbound_send_event(event)?;
            }
            StreamState::LocalClosing | StreamState::Reset | StreamState::Closed => {
                self.state = StreamState::Closed;
                let event = StreamEvent::StateChanged((self.id, self.state));
                self.unbound_send_event(event)?;
            }
        }
        Ok(())
    }

    /// Marks the stream `LocalClosing` and asks the session for a go-away.
    ///
    /// Best-effort: a closed event channel is ignored.
    fn send_go_away(&mut self) {
        self.state = StreamState::LocalClosing;
        let _ignore = self
            .unbound_event_sender
            .unbounded_send(StreamEvent::GoAway);
    }

    /// Sends `event` over the bounded channel without blocking.
    ///
    /// # Errors
    /// * `Error::WouldBlock` if the channel has no capacity. When `poll_ready`
    ///   returned `Pending`, the task in `cx` is registered to be woken.
    /// * `Error::SessionShutdown` if the channel is closed.
    #[inline]
    fn send_event(&mut self, cx: &mut Context, event: StreamEvent) -> Result<(), Error> {
        match self.event_sender.poll_ready(cx) {
            Poll::Ready(Ok(())) => {
                if let Err(e) = self.event_sender.try_send(event) {
                    if e.is_full() {
                        return Err(Error::WouldBlock);
                    } else {
                        return Err(Error::SessionShutdown);
                    }
                }
            }
            Poll::Pending => return Err(Error::WouldBlock),
            Poll::Ready(Err(_)) => return Err(Error::SessionShutdown),
        }
        Ok(())
    }

    /// Sends `event` over the unbounded channel.
    ///
    /// # Errors
    /// `Error::SessionShutdown` if the channel is closed.
    fn unbound_send_event(&mut self, event: StreamEvent) -> Result<(), Error> {
        self.unbound_event_sender
            .unbounded_send(event)
            .map_err(|_| Error::SessionShutdown)
    }

    /// Wraps `frame` in a `StreamEvent` and sends it over the bounded channel.
    #[inline]
    fn send_frame(&mut self, cx: &mut Context, frame: Frame) -> Result<(), Error> {
        let event = StreamEvent::Frame(frame);
        self.send_event(cx, event)
    }

    /// Wraps `frame` in a `StreamEvent` and sends it over the unbounded channel.
    #[inline]
    fn unbound_send_frame(&mut self, frame: Frame) -> Result<(), Error> {
        let event = StreamEvent::Frame(frame);
        self.unbound_send_event(event)
    }

    /// Hands consumed receive window back to the peer.
    ///
    /// The reclaimable amount is `max_recv_window - read_buf.len() -
    /// recv_window`, which is window the peer has used and the reader has
    /// already drained. An update is sent only when that amount reaches half
    /// of `max_recv_window`, or whenever a SYN/ACK flag is pending (so the
    /// handshake flag goes out).
    ///
    /// # Errors
    /// `Error::SessionShutdown` if the unbounded event channel is closed.
    pub(crate) fn send_window_update(&mut self) -> Result<(), Error> {
        let buf_len = self.read_buf.len() as u32;
        // Saturating: if the accounting invariant
        // (read_buf + recv_window <= max) is ever broken, this yields no
        // update. Plain subtraction would panic in debug builds or wrap to a
        // huge bogus window in release builds.
        let delta = self
            .max_recv_window
            .saturating_sub(buf_len)
            .saturating_sub(self.recv_window);
        let flags = self.get_flags();
        if delta < (self.max_recv_window / 2) && flags.value() == 0 {
            return Ok(());
        }
        self.recv_window += delta;
        let frame = Frame::new_window_update(flags, self.id, delta);
        self.unbound_send_frame(frame)
    }

    /// Sends `data` as one data frame over the bounded channel.
    ///
    /// # Errors
    /// The same as `send_event`. On any error the handshake state is left
    /// untouched, so a retry still carries a pending SYN/ACK.
    fn send_data(&mut self, cx: &mut Context, data: &[u8]) -> Result<(), Error> {
        // `get_flags` advances the handshake state (Init -> SynSent,
        // SynReceived -> Established) as if the frame were delivered. If the
        // bounded channel is full, the frame is dropped instead. Roll the state
        // back, otherwise the SYN/ACK flag would never reach the peer.
        let state_before = self.state;
        let flags = self.get_flags();
        let frame = Frame::new_data(flags, self.id, Bytes::from(data.to_owned()));
        let result = self.send_frame(cx, frame);
        if result.is_err() {
            self.state = state_before;
        }
        result
    }

    /// Sends a FIN (a zero-length window update with `Flag::Fin`) over the
    /// unbounded channel, together with any pending SYN/ACK flag.
    fn send_close(&mut self) -> Result<(), Error> {
        let mut flags = self.get_flags();
        flags.add(Flag::Fin);
        let frame = Frame::new_window_update(flags, self.id, 0);
        self.unbound_send_frame(frame)
    }

    /// Applies the flags of an incoming frame to the state machine.
    ///
    /// * ACK while `SynSent` -> `SynReceived`.
    /// * FIN: an open stream becomes `RemoteClosing`; `LocalClosing` becomes
    ///   `Closed` and `close` is called.
    /// * RST -> `Reset`, then `close` is called, which leaves `Closed`.
    ///
    /// # Errors
    /// `Error::UnexpectedFlag` for FIN in `RemoteClosing` / `Closed` / `Reset`,
    /// or any error from `close`.
    fn process_flags(&mut self, flags: Flags) -> Result<(), Error> {
        // NOTE(review): the yamux reference implementation moves SynSent ->
        // Established on ACK. Going to SynReceived makes the next outgoing
        // frame echo an ACK. Kept as-is because callers may rely on it.
        if flags.contains(Flag::Ack) && self.state == StreamState::SynSent {
            self.state = StreamState::SynReceived;
        }
        let mut close_stream = false;
        if flags.contains(Flag::Fin) {
            match self.state {
                StreamState::Init
                | StreamState::SynSent
                | StreamState::SynReceived
                | StreamState::Established => {
                    self.state = StreamState::RemoteClosing;
                }
                StreamState::LocalClosing => {
                    self.state = StreamState::Closed;
                    close_stream = true;
                }
                _ => return Err(Error::UnexpectedFlag),
            }
        }
        if flags.contains(Flag::Rst) {
            self.state = StreamState::Reset;
            close_stream = true;
        }
        if close_stream {
            self.close()?;
        }
        Ok(())
    }

    /// Returns the handshake flags the next outgoing frame must carry, and
    /// advances the state as if that frame were sent.
    ///
    /// `Init` -> `SynSent` (SYN), `SynReceived` -> `Established` (ACK); any
    /// other state yields empty flags and is left unchanged.
    fn get_flags(&mut self) -> Flags {
        match self.state {
            StreamState::Init => {
                self.state = StreamState::SynSent;
                Flags::from(Flag::Syn)
            }
            StreamState::SynReceived => {
                self.state = StreamState::Established;
                Flags::from(Flag::Ack)
            }
            _ => Flags::default(),
        }
    }

    /// Dispatches an incoming frame by type.
    ///
    /// # Errors
    /// `Error::InvalidMsgType` for frame types other than window update or
    /// data, plus any error from the per-type handler.
    fn handle_frame(&mut self, frame: Frame) -> Result<(), Error> {
        match frame.ty() {
            Type::WindowUpdate => {
                self.handle_window_update(&frame)?;
            }
            Type::Data => {
                self.handle_data(frame)?;
            }
            _ => {
                return Err(Error::InvalidMsgType);
            }
        }
        Ok(())
    }

    /// Applies a window update: processes its flags, grows `send_window` by
    /// the frame length, and wakes a parked writer.
    ///
    /// # Errors
    /// `Error::InvalidMsgType` if the window would overflow `u32`, or any
    /// error from `process_flags`.
    fn handle_window_update(&mut self, frame: &Frame) -> Result<(), Error> {
        self.process_flags(frame.flags())?;
        self.send_window = self
            .send_window
            .checked_add(frame.length())
            .ok_or(Error::InvalidMsgType)?;
        if let Some(waker) = self.writeable_wake.take() {
            waker.wake()
        }
        Ok(())
    }

    /// Buffers an incoming data frame and charges it to `recv_window`.
    ///
    /// # Errors
    /// `Error::RecvWindowExceeded` if the frame is longer than the remaining
    /// receive window, or any error from `process_flags`. Flags are processed
    /// before the window check.
    fn handle_data(&mut self, frame: Frame) -> Result<(), Error> {
        self.process_flags(frame.flags())?;
        let length = frame.length();
        if length > self.recv_window {
            return Err(Error::RecvWindowExceeded);
        }
        let (_, body) = frame.into_parts();
        if let Some(data) = body {
            self.read_buf.extend_from_slice(&data);
        }
        self.recv_window -= length;
        Ok(())
    }

    /// Pulls frames from `frame_receiver` until data is buffered or no frame
    /// is ready. In the latter case the task in `cx` is registered for
    /// wake-up.
    ///
    /// # Errors
    /// * `Error::SubStreamRemoteClosing` while `RemoteClosing`.
    /// * `Error::SessionShutdown` while `Reset` / `Closed`, or when the frame
    ///   channel has ended (which also moves the stream to `RemoteClosing`).
    /// * Any error from `handle_frame`.
    fn recv_frames(&mut self, cx: &mut Context) -> Result<(), Error> {
        loop {
            match self.state {
                StreamState::RemoteClosing => {
                    return Err(Error::SubStreamRemoteClosing);
                }
                StreamState::Reset | StreamState::Closed => {
                    return Err(Error::SessionShutdown);
                }
                _ => {}
            }
            if !self.read_buf.is_empty() {
                break;
            }
            // A terminated receiver must not be polled again.
            if self.frame_receiver.is_terminated() {
                self.state = StreamState::RemoteClosing;
                return Err(Error::SessionShutdown);
            }
            match Pin::new(&mut self.frame_receiver).poll_next(cx) {
                Poll::Ready(Some(frame)) => self.handle_frame(frame)?,
                Poll::Ready(None) => {
                    self.state = StreamState::RemoteClosing;
                    return Err(Error::SessionShutdown);
                }
                Poll::Pending => break,
            }
        }
        Ok(())
    }

    /// Maps a terminal state to the I/O error `poll_read` should report.
    ///
    /// Only consulted when `read_buf` is empty, so buffered data is always
    /// delivered first:
    /// * `RemoteClosing` -> replies with FIN (best-effort) and `UnexpectedEof`.
    /// * `Reset` -> replies with FIN (best-effort) and `ConnectionReset`.
    /// * `Closed` -> `BrokenPipe`.
    ///
    /// NOTE(review): the FIN reply is re-sent on every call made in
    /// `RemoteClosing` / `Reset`.
    fn check_self_state(&mut self) -> Result<(), io::Error> {
        if self.read_buf.is_empty() {
            match self.state {
                StreamState::RemoteClosing => {
                    debug!("closed(EOF)");
                    let _ignore = self.send_close();
                    Err(io::ErrorKind::UnexpectedEof.into())
                }
                StreamState::Reset => {
                    debug!("connection reset");
                    let _ignore = self.send_close();
                    Err(io::ErrorKind::ConnectionReset.into())
                }
                StreamState::Closed => Err(io::ErrorKind::BrokenPipe.into()),
                _ => Ok(()),
            }
        } else {
            Ok(())
        }
    }
}
impl AsyncRead for StreamHandle {
    /// Copies buffered stream data into `buf`, first pulling any frames
    /// queued for this stream.
    ///
    /// * `Ready(Ok(0))`: `buf` is empty, or the peer has closed its side (FIN)
    ///   and every buffered byte has been consumed (end of stream).
    /// * `Ready(Ok(n))`: `n > 0` bytes were copied. While the stream can still
    ///   receive, the freed window may be returned to the peer via
    ///   `send_window_update`.
    /// * `Ready(Err(InvalidData))`: the peer violated the protocol
    ///   (unexpected flag, window exceeded, bad frame type); a go-away has
    ///   been requested.
    /// * `Ready(Err(ConnectionReset | BrokenPipe))`: the stream was reset or
    ///   fully closed, or a window update could not be sent.
    /// * `Pending`: no data yet; the frame receiver will wake this task.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // Nothing can be copied into an empty buffer. Falling through would
        // return `Pending` with no waker registered whenever data is already
        // buffered, hanging the caller forever.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        // `check_self_state` reports a drained, remote-closed stream as
        // `UnexpectedEof`. `AsyncRead` signals end-of-stream with `Ok(0)`.
        let eof_as_zero = |err: io::Error| -> io::Result<usize> {
            if err.kind() == io::ErrorKind::UnexpectedEof {
                Ok(0)
            } else {
                Err(err)
            }
        };

        if let Err(err) = self.check_self_state() {
            return Poll::Ready(eof_as_zero(err));
        }
        if let Err(e) = self.recv_frames(cx) {
            match e {
                Error::UnexpectedFlag | Error::RecvWindowExceeded | Error::InvalidMsgType => {
                    self.send_go_away();
                    return Poll::Ready(Err(io::ErrorKind::InvalidData.into()));
                }
                // Closing/shutdown conditions are reported by the state check
                // below once the buffer is drained.
                _ => (),
            }
        }
        if let Err(err) = self.check_self_state() {
            return Poll::Ready(eof_as_zero(err));
        }

        let n = ::std::cmp::min(buf.len(), self.read_buf.len());
        if n == 0 {
            // On this path `recv_frames` polled `frame_receiver` and got
            // `Pending`, which registered this task's waker.
            return Poll::Pending;
        }
        let chunk = self.read_buf.split_to(n);
        buf[..n].copy_from_slice(&chunk);

        match self.state {
            StreamState::RemoteClosing | StreamState::Closed | StreamState::Reset => (),
            // This includes `LocalClosing`: after our FIN the peer may keep
            // sending until its own FIN arrives (`process_flags` then moves the
            // stream to `Closed`). Keep returning window instead of closing the
            // stream early and cutting that data off.
            _ => {
                if self.send_window_update().is_err() {
                    return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()));
                }
            }
        }
        Poll::Ready(Ok(n))
    }

    /// Returning `false` is sound because `poll_read` only writes into `buf`
    /// (via `copy_from_slice`) and never reads its prior contents.
    unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
        false
    }
}
impl AsyncWrite for StreamHandle {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match self.state {
StreamState::RemoteClosing | StreamState::Reset => {
return Poll::Ready(Err(io::ErrorKind::BrokenPipe.into()))
}
StreamState::LocalClosing | StreamState::Closed => {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"The local is closed and data cannot be written.",
)));
}
_ => (),
}
if self.send_window == 0 {
self.writeable_wake = Some(cx.waker().clone());
return Poll::Pending;
}
let n = ::std::cmp::min(self.send_window as usize, buf.len());
let data = &buf[0..n];
match self.send_data(cx, data) {
Ok(_) => {
self.send_window -= n as u32;
Poll::Ready(Ok(n))
}
Err(Error::WouldBlock) => Poll::Pending,
_ => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
debug!("[{}] StreamHandle.shutdown()", self.id);
match self.close() {
Err(_) => Poll::Ready(Err(io::ErrorKind::BrokenPipe.into())),
Ok(()) => Poll::Ready(Ok(())),
}
}
}
impl Drop for StreamHandle {
    /// Resets the stream on drop unless this side has already closed it.
    ///
    /// Emission requires both of:
    /// * the bounded event channel is still open;
    /// * the state is neither `Closed` nor `LocalClosing`.
    ///
    /// In that case two events are pushed onto the unbounded channel:
    /// an RST frame, which also carries any pending SYN/ACK from `get_flags`,
    /// followed by `StateChanged(Closed)`. Send failures are ignored.
    fn drop(&mut self) {
        let locally_closed = match self.state {
            StreamState::Closed | StreamState::LocalClosing => true,
            _ => false,
        };
        if locally_closed || self.event_sender.is_closed() {
            return;
        }

        let mut rst_flags = self.get_flags();
        rst_flags.add(Flag::Rst);
        let id = self.id;
        let _ = self
            .unbound_event_sender
            .unbounded_send(StreamEvent::Frame(Frame::new_window_update(rst_flags, id, 0)));
        let _ = self
            .unbound_event_sender
            .unbounded_send(StreamEvent::StateChanged((id, StreamState::Closed)));
    }
}
/// Message a `StreamHandle` sends to its session over the event channels.
#[derive(Debug, Eq, PartialEq)]
pub(crate) enum StreamEvent {
/// An outgoing frame built by the handle (data, window update, FIN or RST).
Frame(Frame),
/// The stream with the given id moved to the given state.
StateChanged((StreamId, StreamState)),
/// Go-away request, emitted by `send_go_away` after a protocol error was
/// detected while reading.
GoAway,
}
/// Lifecycle state of a stream, driven by `get_flags`, `process_flags` and
/// `close`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum StreamState {
/// Nothing sent yet; the next outgoing frame carries SYN.
Init,
/// SYN sent; an incoming ACK moves the stream to `SynReceived`.
SynSent,
/// The next outgoing frame carries ACK and moves the stream to `Established`.
SynReceived,
/// Handshake flags have been sent; normal data transfer.
Established,
/// This side sent FIN (or requested a go-away); a FIN from the peer moves the
/// stream to `Closed`. Writes are rejected.
LocalClosing,
/// The peer sent FIN or the frame channel ended; writes are rejected and
/// reads end once the buffer is drained.
RemoteClosing,
/// Both directions are closed; reads and writes fail.
Closed,
/// Set when an RST flag is received; `process_flags` then calls `close`,
/// which moves the stream to `Closed`.
Reset,
}
#[cfg(test)]
mod test {
    use super::{StreamEvent, StreamHandle, StreamState};
    use crate::{
        config::INITIAL_STREAM_WINDOW,
        frame::{Flag, Flags, Frame},
    };
    use bytes::Bytes;
    use futures::{
        channel::mpsc::{channel, unbounded, Receiver, Sender, UnboundedReceiver},
        SinkExt, StreamExt,
    };
    use std::io::ErrorKind;
    use tokio::io::AsyncReadExt;

    /// Session-side ends of the channels wired into a test `StreamHandle`.
    ///
    /// Every end lives as long as the `Peer`. In particular the bounded event
    /// receiver must stay alive: if it were dropped, `Drop for StreamHandle`
    /// would see a closed sender and emit nothing.
    struct Peer {
        _events: Receiver<StreamEvent>,
        frames: Sender<Frame>,
        unbound: UnboundedReceiver<StreamEvent>,
    }

    /// Builds stream 0 in `Init` state with the given receive window and the
    /// default send window, plus the matching session-side channel ends.
    fn new_stream(recv_window: u32) -> (StreamHandle, Peer) {
        let (event_tx, event_rx) = channel(2);
        let (frame_tx, frame_rx) = channel(2);
        let (unbound_tx, unbound_rx) = unbounded();
        let stream = StreamHandle::new(
            0,
            event_tx,
            unbound_tx,
            frame_rx,
            StreamState::Init,
            recv_window,
            INITIAL_STREAM_WINDOW,
        );
        let peer = Peer {
            _events: event_rx,
            frames: frame_tx,
            unbound: unbound_rx,
        };
        (stream, peer)
    }

    #[test]
    fn test_drop() {
        let mut rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let (stream, mut peer) = new_stream(INITIAL_STREAM_WINDOW);
            drop(stream);

            if let StreamEvent::Frame(frame) = peer.unbound.next().await.unwrap() {
                assert!(frame.flags().contains(Flag::Rst));
            } else {
                panic!("must be a frame msg contain RST");
            }

            if let StreamEvent::StateChanged((_, state)) = peer.unbound.next().await.unwrap() {
                assert_eq!(state, StreamState::Closed);
            } else {
                panic!("must be state change");
            }
        });
    }

    #[test]
    fn test_drop_with_state_reset() {
        let mut rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let (mut stream, mut peer) = new_stream(INITIAL_STREAM_WINDOW);

            let mut flags = Flags::from(Flag::Syn);
            flags.add(Flag::Rst);
            peer.frames
                .send(Frame::new_window_update(flags, 0, 0))
                .await
                .unwrap();

            let mut scratch = [0; 1024];
            let err = stream.read(&mut scratch).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::BrokenPipe);

            drop(stream);
            if let StreamEvent::StateChanged((_, state)) = peer.unbound.next().await.unwrap() {
                assert_eq!(state, StreamState::Closed);
            } else {
                panic!("must be state change");
            }
        });
    }

    #[test]
    fn test_data_large_than_recv_window() {
        let mut rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async {
            let (mut stream, mut peer) = new_stream(2);

            let oversized = Frame::new_data(Flags::from(Flag::Syn), 0, Bytes::from("1234"));
            peer.frames.send(oversized).await.unwrap();

            let mut scratch = [0; 1024];
            let err = stream.read(&mut scratch).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::InvalidData);

            if let StreamEvent::GoAway = peer.unbound.next().await.unwrap() {
            } else {
                panic!("must be go away");
            }
        });
    }
}