use std::{
future::Future,
io,
mem::MaybeUninit,
pin::Pin,
str,
task::{Context, Poll},
};
use bytes::Bytes;
use err_derive::Error;
use futures::{
channel::oneshot,
io::{AsyncRead, AsyncWrite},
ready, FutureExt,
};
use proto::{ConnectionError, StreamId};
use crate::{connection::ConnectionRef, VarInt};
/// Sending half of a stream on a connection.
///
/// Data is queued with [`SendStream::write`] / [`SendStream::write_all`], the
/// stream is closed with [`SendStream::finish`], and abandoned with
/// [`SendStream::reset`].
#[derive(Debug)]
pub struct SendStream<S>
where
S: proto::crypto::Session,
{
// Shared connection state; locked for the duration of every operation.
conn: ConnectionRef<S>,
// Identifier passed to the connection for every call made on this stream.
stream: StreamId,
// When set, every operation first consults `check_0rtt()` on the connection.
is_0rtt: bool,
// `None` until `poll_finish` has initiated a finish; afterwards holds the
// receiver on which the outcome of the finish is delivered.
finishing: Option<oneshot::Receiver<Option<WriteError>>>,
}
impl<S> SendStream<S>
where
S: proto::crypto::Session,
{
pub(crate) fn new(conn: ConnectionRef<S>, stream: StreamId, is_0rtt: bool) -> Self {
Self {
conn,
stream,
is_0rtt,
finishing: None,
}
}
pub fn write<'a>(&'a mut self, buf: &'a [u8]) -> Write<'a, S> {
Write { stream: self, buf }
}
pub fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> WriteAll<'a, S> {
WriteAll { stream: self, buf }
}
fn poll_write(&mut self, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, WriteError>> {
use proto::WriteError::*;
let mut conn = self.conn.lock().unwrap();
if self.is_0rtt {
conn.check_0rtt()
.map_err(|()| WriteError::ZeroRttRejected)?;
}
if let Some(ref x) = conn.error {
return Poll::Ready(Err(WriteError::ConnectionClosed(x.clone())));
}
let n = match conn.inner.write(self.stream, buf) {
Ok(n) => n,
Err(Blocked) => {
conn.blocked_writers.insert(self.stream, cx.waker().clone());
return Poll::Pending;
}
Err(Stopped(error_code)) => {
return Poll::Ready(Err(WriteError::Stopped(error_code)));
}
Err(UnknownStream) => {
return Poll::Ready(Err(WriteError::UnknownStream));
}
};
conn.wake();
Poll::Ready(Ok(n))
}
pub fn finish(&mut self) -> Finish<'_, S> {
Finish { stream: self }
}
fn poll_finish(&mut self, cx: &mut Context) -> Poll<Result<(), WriteError>> {
let mut conn = self.conn.lock().unwrap();
if self.is_0rtt {
conn.check_0rtt()
.map_err(|()| WriteError::ZeroRttRejected)?;
}
if self.finishing.is_none() {
conn.inner.finish(self.stream).map_err(|e| match e {
proto::FinishError::UnknownStream => WriteError::UnknownStream,
proto::FinishError::Stopped(error_code) => WriteError::Stopped(error_code),
})?;
let (send, recv) = oneshot::channel();
self.finishing = Some(recv);
conn.finishing.insert(self.stream, send);
conn.wake();
}
match self
.finishing
.as_mut()
.unwrap()
.poll_unpin(cx)
.map(|x| x.unwrap())
{
Poll::Ready(None) => Poll::Ready(Ok(())),
Poll::Ready(Some(e)) => Poll::Ready(Err(e)),
Poll::Pending => {
if let Some(ref x) = conn.error {
return Poll::Ready(Err(WriteError::ConnectionClosed(x.clone())));
}
Poll::Pending
}
}
}
pub fn reset(&mut self, error_code: VarInt) {
let mut conn = self.conn.lock().unwrap();
if self.is_0rtt && conn.check_0rtt().is_err() {
return;
}
conn.inner.reset(self.stream, error_code);
conn.wake();
}
#[doc(hidden)]
pub fn id(&self) -> StreamId {
self.stream
}
}
/// `futures` I/O adapter: forwards to the inherent polling methods and maps
/// [`WriteError`] into [`io::Error`].
impl<S> AsyncWrite for SendStream<S>
where
    S: proto::crypto::Session,
{
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        // Explicit path selects the inherent method rather than this trait method.
        let written = SendStream::poll_write(this, cx, buf);
        written.map_err(io::Error::from)
    }

    /// Flushing is a no-op and always reports success.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        let flushed: io::Result<()> = Ok(());
        Poll::Ready(flushed)
    }

    /// Closing maps onto finishing the stream.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        this.poll_finish(cx).map_err(io::Error::from)
    }
}
impl<S> tokio::io::AsyncWrite for SendStream<S>
where
S: proto::crypto::Session,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
AsyncWrite::poll_write(self, cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
AsyncWrite::poll_close(self, cx)
}
}
impl<S> Drop for SendStream<S>
where
    S: proto::crypto::Session,
{
    /// Finishes the stream implicitly when the handle goes away, unless a
    /// finish was already initiated, the connection has an error, or the
    /// stream is a 0-RTT stream whose `check_0rtt()` fails.
    fn drop(&mut self) {
        let mut conn = self.conn.lock().unwrap();
        if conn.error.is_some() {
            return;
        }
        if self.is_0rtt && conn.check_0rtt().is_err() {
            return;
        }
        // Only wake the connection when the implicit finish was accepted.
        if self.finishing.is_none() && conn.inner.finish(self.stream).is_ok() {
            conn.wake();
        }
    }
}
/// Future returned by [`SendStream::finish`].
///
/// Resolves once the finish has an outcome or an error is reported.
pub struct Finish<'a, S>
where
    S: proto::crypto::Session,
{
    stream: &'a mut SendStream<S>,
}

impl<S> Future for Finish<'_, S>
where
    S: proto::crypto::Session,
{
    type Output = Result<(), WriteError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let Finish { stream } = self.get_mut();
        stream.poll_finish(cx)
    }
}
/// Receiving half of a stream on a connection.
///
/// Data can be read in order ([`RecvStream::read`], [`RecvStream::read_exact`]),
/// as it arrives ([`RecvStream::read_unordered`]), or all at once
/// ([`RecvStream::read_to_end`]).
#[derive(Debug)]
pub struct RecvStream<S>
where
S: proto::crypto::Session,
{
// Shared connection state; locked for the duration of every operation.
conn: ConnectionRef<S>,
// Identifier passed to the connection for every call made on this stream.
stream: StreamId,
// When set, every operation first consults `check_0rtt()` on the connection.
is_0rtt: bool,
// Set once the stream reported end-of-data or a reset, or after a successful
// `stop`; when set, `Drop` does not send a stop request.
all_data_read: bool,
// Set as soon as any read has been attempted on this handle.
any_data_read: bool,
}
impl<S> RecvStream<S>
where
    S: proto::crypto::Session,
{
    /// Builds a receive stream handle for `stream` on `conn`.
    ///
    /// Both read-tracking flags start out cleared.
    pub(crate) fn new(conn: ConnectionRef<S>, stream: StreamId, is_0rtt: bool) -> Self {
        Self {
            conn,
            stream,
            is_0rtt,
            any_data_read: false,
            all_data_read: false,
        }
    }

    /// Returns a future that reads in-order data into `buf`.
    ///
    /// The future yields `Some(n)` for `n` bytes read, or `None` when the
    /// connection reports that there is no more data.
    pub fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Read<'a, S> {
        Read { buf, stream: self }
    }

    /// Returns a future that reads until `buf` is completely filled.
    pub fn read_exact<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadExact<'a, S> {
        ReadExact {
            buf,
            off: 0,
            stream: self,
        }
    }

    /// Attempts a single in-order read into `buf`.
    ///
    /// When the stream is blocked, reports the connection error if one is set,
    /// otherwise registers the task's waker in `blocked_readers`.
    fn poll_read(
        &mut self,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<Option<usize>, ReadError>> {
        self.any_data_read = true;
        let mut conn = self.conn.lock().unwrap();
        if self.is_0rtt && conn.check_0rtt().is_err() {
            return Poll::Ready(Err(ReadError::ZeroRttRejected));
        }
        let outcome = match conn.inner.read(self.stream, buf) {
            Ok(Some(len)) => Ok(Some(len)),
            Ok(None) => {
                self.all_data_read = true;
                Ok(None)
            }
            Err(proto::ReadError::Blocked) => {
                if let Some(reason) = conn.error.as_ref() {
                    return Poll::Ready(Err(ReadError::ConnectionClosed(reason.clone())));
                }
                conn.blocked_readers.insert(self.stream, cx.waker().clone());
                return Poll::Pending;
            }
            Err(proto::ReadError::Reset(error_code)) => {
                // Nothing further can be read after a reset.
                self.all_data_read = true;
                Err(ReadError::Reset(error_code))
            }
            Err(proto::ReadError::UnknownStream) => Err(ReadError::UnknownStream),
        };
        Poll::Ready(outcome)
    }

    /// Returns a future that yields the next chunk of data together with its
    /// offset in the stream, without enforcing ordering.
    pub fn read_unordered(&mut self) -> ReadUnordered<'_, S> {
        ReadUnordered { stream: self }
    }

    /// Attempts to fetch one chunk and its stream offset.
    ///
    /// Blocking and error handling mirror [`RecvStream::poll_read`].
    fn poll_read_unordered(
        &mut self,
        cx: &mut Context,
    ) -> Poll<Result<Option<(Bytes, u64)>, ReadError>> {
        self.any_data_read = true;
        let mut conn = self.conn.lock().unwrap();
        if self.is_0rtt && conn.check_0rtt().is_err() {
            return Poll::Ready(Err(ReadError::ZeroRttRejected));
        }
        let outcome = match conn.inner.read_unordered(self.stream) {
            Ok(Some(chunk)) => Ok(Some(chunk)),
            Ok(None) => {
                self.all_data_read = true;
                Ok(None)
            }
            Err(proto::ReadError::Blocked) => {
                if let Some(reason) = conn.error.as_ref() {
                    return Poll::Ready(Err(ReadError::ConnectionClosed(reason.clone())));
                }
                conn.blocked_readers.insert(self.stream, cx.waker().clone());
                return Poll::Pending;
            }
            Err(proto::ReadError::Reset(error_code)) => {
                // Nothing further can be read after a reset.
                self.all_data_read = true;
                Err(ReadError::Reset(error_code))
            }
            Err(proto::ReadError::UnknownStream) => Err(ReadError::UnknownStream),
        };
        Poll::Ready(outcome)
    }

    /// Consumes the stream and returns a future that collects all remaining
    /// data, failing if it spans more than `size_limit` bytes.
    pub fn read_to_end(self, size_limit: usize) -> ReadToEnd<S> {
        // `start` begins at the maximum so the first chunk's offset always wins `min`.
        let (start, end) = (u64::max_value(), 0);
        ReadToEnd {
            stream: self,
            read: Vec::new(),
            start,
            end,
            size_limit,
        }
    }

    /// Asks the peer to stop sending, using `error_code`.
    ///
    /// Returns `Ok(())` without doing anything for a 0-RTT stream whose
    /// `check_0rtt()` fails.
    ///
    /// # Errors
    ///
    /// Returns [`UnknownStream`] when the connection does not know the stream.
    pub fn stop(&mut self, error_code: VarInt) -> Result<(), UnknownStream> {
        let mut conn = self.conn.lock().unwrap();
        if !self.is_0rtt || conn.check_0rtt().is_ok() {
            conn.inner.stop_sending(self.stream, error_code)?;
            conn.wake();
            self.all_data_read = true;
        }
        Ok(())
    }

    /// Reports whether this stream was flagged as 0-RTT at construction.
    pub fn is_0rtt(&self) -> bool {
        self.is_0rtt
    }
}
/// Future returned by [`RecvStream::read_to_end`].
///
/// Owns the stream and accumulates chunks until the stream reports no more data.
pub struct ReadToEnd<S>
where
S: proto::crypto::Session,
{
stream: RecvStream<S>,
// Chunks received so far, each paired with its offset in the stream.
read: Vec<(Bytes, u64)>,
// Lowest chunk offset seen so far; starts at `u64::max_value()`.
start: u64,
// Highest chunk end (offset + length) seen so far; starts at 0.
end: u64,
// Upper bound on the number of bytes the future is allowed to return.
size_limit: usize,
}
impl<S> Future for ReadToEnd<S>
where
    S: proto::crypto::Session,
{
    type Output = Result<Vec<u8>, ReadToEndError>;

    /// Collects unordered chunks until the stream ends, then assembles them
    /// into one contiguous buffer.
    ///
    /// # Errors
    ///
    /// * [`ReadToEndError::TooLong`] as soon as the span covered by all chunks
    ///   seen so far (lowest offset to highest end) exceeds `size_limit`.
    /// * [`ReadToEndError::Read`] when the underlying read fails.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        loop {
            match ready!(self.stream.poll_read_unordered(cx))? {
                Some((data, offset)) => {
                    self.start = self.start.min(offset);
                    self.end = self.end.max(data.len() as u64 + offset);
                    // Check the span of *everything* received so far. Chunks
                    // arrive out of order, so measuring only the newest chunk's
                    // end would let an early high-offset chunk slip past the
                    // limit and inflate the final allocation.
                    if self.end - self.start > self.size_limit as u64 {
                        return Poll::Ready(Err(ReadToEndError::TooLong));
                    }
                    self.read.push((data, offset));
                }
                None => {
                    if self.end == 0 {
                        return Poll::Ready(Ok(Vec::new()));
                    }
                    let start = self.start;
                    // The span was checked against `size_limit` (a `usize`)
                    // above, so this cast cannot truncate.
                    let mut buffer = vec![0; (self.end - start) as usize];
                    for (data, offset) in self.read.drain(..) {
                        let offset = (offset - start) as usize;
                        buffer[offset..offset + data.len()].copy_from_slice(&data);
                    }
                    return Poll::Ready(Ok(buffer));
                }
            }
        }
    }
}
/// Error produced by the [`ReadToEnd`] future.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadToEndError {
/// The underlying stream read failed; the cause is carried as the source.
#[error(display = "read error")]
Read(#[source] ReadError),
/// The stream data spanned more bytes than the caller's size limit.
#[error(display = "stream too long")]
TooLong,
}
/// `futures` I/O adapter: end-of-stream is reported as a zero-length read and
/// [`ReadError`] is mapped into [`io::Error`].
impl<S> AsyncRead for RecvStream<S>
where
    S: proto::crypto::Session,
{
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // Explicit path selects the inherent method rather than this trait method.
        let outcome = ready!(RecvStream::poll_read(self.get_mut(), cx, buf));
        let translated = outcome
            .map(|len| len.unwrap_or(0))
            .map_err(io::Error::from);
        Poll::Ready(translated)
    }
}
impl<S> tokio::io::AsyncRead for RecvStream<S>
where
S: proto::crypto::Session,
{
unsafe fn prepare_uninitialized_buffer(&self, _: &mut [MaybeUninit<u8>]) -> bool {
false
}
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
AsyncRead::poll_read(self, cx, buf)
}
}
impl<S> Drop for RecvStream<S>
where
    S: proto::crypto::Session,
{
    /// Sends a stop request with error code 0 when the handle is dropped
    /// before all data was read, unless the connection has an error or the
    /// stream is a 0-RTT stream whose `check_0rtt()` fails.
    fn drop(&mut self) {
        let mut conn = self.conn.lock().unwrap();
        if conn.error.is_some() {
            return;
        }
        if self.is_0rtt && conn.check_0rtt().is_err() {
            return;
        }
        if self.all_data_read {
            return;
        }
        // Best effort: a failure to stop is deliberately ignored.
        let _ = conn.inner.stop_sending(self.stream, 0u32.into());
        conn.wake();
    }
}
/// Errors that can arise while reading from a [`RecvStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadError {
/// The peer reset the stream; carries the error code reported with the reset.
// NOTE(review): this attribute references the field as a bare `0`, while
// `ConnectionClosed` below uses `_0` — confirm both forms resolve to the
// first tuple field.
#[error(display = "stream reset by peer: error {}", 0)]
Reset(VarInt),
/// The connection was closed; carries the connection-level error.
#[error(display = "connection closed: {}", _0)]
ConnectionClosed(ConnectionError),
/// The connection does not know this stream.
#[error(display = "unknown stream")]
UnknownStream,
/// The stream is a 0-RTT stream and the connection's `check_0rtt()` failed.
#[error(display = "0-RTT rejected")]
ZeroRttRejected,
}
impl From<ReadError> for io::Error {
fn from(x: ReadError) -> Self {
use self::ReadError::*;
let kind = match x {
ConnectionClosed(e) => {
return e.into();
}
Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
UnknownStream => io::ErrorKind::NotConnected,
};
io::Error::new(kind, x)
}
}
/// Errors that can arise while writing to or finishing a [`SendStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
/// The peer stopped the stream; carries the error code reported with the stop.
// NOTE(review): this attribute references the field as a bare `0`, while
// `ConnectionClosed` below uses `_0` — confirm both forms resolve to the
// first tuple field.
#[error(display = "sending stopped by peer: error {}", 0)]
Stopped(VarInt),
/// The connection was closed; carries the connection-level error.
#[error(display = "connection closed: {}", _0)]
ConnectionClosed(ConnectionError),
/// The connection does not know this stream.
#[error(display = "unknown stream")]
UnknownStream,
/// The stream is a 0-RTT stream and the connection's `check_0rtt()` failed.
#[error(display = "0-RTT rejected")]
ZeroRttRejected,
}
impl From<WriteError> for io::Error {
fn from(x: WriteError) -> Self {
use self::WriteError::*;
let kind = match x {
ConnectionClosed(e) => {
return e.into();
}
Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
UnknownStream => io::ErrorKind::NotConnected,
};
io::Error::new(kind, x)
}
}
/// Future returned by [`RecvStream::read`].
///
/// Yields `Some(n)` when `n` bytes were placed in the buffer and `None` when
/// the stream has no more data.
pub struct Read<'a, S>
where
    S: proto::crypto::Session,
{
    stream: &'a mut RecvStream<S>,
    buf: &'a mut [u8],
}

impl<'a, S> Future for Read<'a, S>
where
    S: proto::crypto::Session,
{
    type Output = Result<Option<usize>, ReadError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let me = self.get_mut();
        // Reborrow the two fields separately so both can be used in one call.
        let (stream, target) = (&mut *me.stream, &mut *me.buf);
        stream.poll_read(cx, target)
    }
}
/// Future returned by [`RecvStream::read_exact`].
///
/// Completes once the whole buffer has been filled.
pub struct ReadExact<'a, S>
where
    S: proto::crypto::Session,
{
    stream: &'a mut RecvStream<S>,
    // Number of bytes of `buf` filled so far.
    off: usize,
    buf: &'a mut [u8],
}

impl<'a, S> Future for ReadExact<'a, S>
where
    S: proto::crypto::Session,
{
    type Output = Result<(), ReadExactError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let me = self.get_mut();
        loop {
            let unfilled = &mut me.buf[me.off..];
            if unfilled.is_empty() {
                return Poll::Ready(Ok(()));
            }
            match ready!(me.stream.poll_read(cx, unfilled)) {
                Ok(Some(len)) => me.off += len,
                // The stream ended before the buffer could be filled.
                Ok(None) => return Poll::Ready(Err(ReadExactError::FinishedEarly)),
                Err(cause) => return Poll::Ready(Err(ReadExactError::ReadError(cause))),
            }
        }
    }
}
/// Error produced by the [`ReadExact`] future.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
/// The stream reported no more data before the buffer was filled.
#[error(display = "stream finished early")]
FinishedEarly,
/// The underlying stream read failed.
#[error(display = "{}", 0)]
ReadError(ReadError),
}
/// Future returned by [`RecvStream::read_unordered`].
///
/// Yields the next chunk together with its offset in the stream, or `None`
/// when the stream has no more data.
pub struct ReadUnordered<'a, S>
where
    S: proto::crypto::Session,
{
    stream: &'a mut RecvStream<S>,
}

impl<'a, S> Future for ReadUnordered<'a, S>
where
    S: proto::crypto::Session,
{
    type Output = Result<Option<(Bytes, u64)>, ReadError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let ReadUnordered { stream } = self.get_mut();
        stream.poll_read_unordered(cx)
    }
}
/// Future returned by [`SendStream::write`].
///
/// Resolves to the number of bytes of the buffer that were accepted.
pub struct Write<'a, S>
where
    S: proto::crypto::Session,
{
    stream: &'a mut SendStream<S>,
    buf: &'a [u8],
}

impl<'a, S> Future for Write<'a, S>
where
    S: proto::crypto::Session,
{
    type Output = Result<usize, WriteError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let me = self.get_mut();
        let payload = me.buf;
        me.stream.poll_write(cx, payload)
    }
}
/// Future returned by [`SendStream::write_all`].
///
/// Keeps writing until the whole buffer has been accepted.
pub struct WriteAll<'a, S>
where
    S: proto::crypto::Session,
{
    stream: &'a mut SendStream<S>,
    // Remaining unwritten suffix; shrinks after every successful write.
    buf: &'a [u8],
}

impl<'a, S> Future for WriteAll<'a, S>
where
    S: proto::crypto::Session,
{
    type Output = Result<(), WriteError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let me = self.get_mut();
        while !me.buf.is_empty() {
            let accepted = match me.stream.poll_write(cx, me.buf) {
                Poll::Ready(Ok(count)) => count,
                Poll::Ready(Err(cause)) => return Poll::Ready(Err(cause)),
                Poll::Pending => return Poll::Pending,
            };
            // Drop the prefix that was just accepted.
            me.buf = &me.buf[accepted..];
        }
        Poll::Ready(Ok(()))
    }
}
/// Error returned by [`RecvStream::stop`] when the connection does not know
/// the stream.
#[derive(Debug)]
pub struct UnknownStream {}

/// Converts the protocol-level unknown-stream error into this crate's type.
impl From<proto::UnknownStream> for UnknownStream {
    fn from(_source: proto::UnknownStream) -> Self {
        Self {}
    }
}