use std::{
future::{Future, poll_fn},
io,
pin::{Pin, pin},
task::{Context, Poll},
};
use bytes::Bytes;
use proto::{ClosedStream, ConnectionError, FinishError, StreamId, Written};
use thiserror::Error;
use crate::{
VarInt,
connection::{ConnectionRef, State},
};
/// Handle for the sending half of a stream on a QUIC connection.
///
/// All operations go through the connection state behind `conn`, which is reached by taking
/// its lock. Dropping the handle finishes the stream (or resets it with the peer's error code
/// if the peer already stopped it), unless the connection has failed or, for a 0-RTT stream,
/// 0-RTT was rejected.
#[derive(Debug)]
pub struct SendStream {
    /// Handle to the owning connection; its state is accessed via `conn.state.lock(..)`.
    conn: ConnectionRef,
    /// Identifier of this stream within `conn`.
    stream: StreamId,
    /// Whether the stream was opened as 0-RTT data; if so, several operations first check
    /// whether 0-RTT was rejected.
    is_0rtt: bool,
}
impl SendStream {
    /// Creates a handle for stream `stream` on connection `conn`.
    ///
    /// `is_0rtt` marks a stream opened as 0-RTT data. Writes, [`reset`](Self::reset),
    /// [`stopped`](Self::stopped) and dropping the handle then check whether 0-RTT was
    /// rejected before touching the stream.
    pub(crate) fn new(conn: ConnectionRef, stream: StreamId, is_0rtt: bool) -> Self {
        Self {
            conn,
            stream,
            is_0rtt,
        }
    }

    /// Writes bytes from `buf` to the stream and returns how many were accepted.
    ///
    /// The count may be smaller than `buf.len()`; use [`write_all`](Self::write_all) to send
    /// an entire buffer. While the protocol layer reports the stream as blocked, the call
    /// waits until this task is woken.
    ///
    /// # Errors
    ///
    /// - [`WriteError::ZeroRttRejected`] for a 0-RTT stream whose 0-RTT data was rejected.
    /// - [`WriteError::ConnectionLost`] if the connection has failed.
    /// - [`WriteError::Stopped`] if the peer stopped the stream.
    /// - [`WriteError::ClosedStream`] if the protocol layer reports the stream as closed.
    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
        poll_fn(|cx| self.execute_poll(cx, |s| s.write(buf))).await
    }

    /// Writes all of `buf`, calling [`write`](Self::write) until nothing is left.
    ///
    /// # Errors
    ///
    /// Returns the first error from [`write`](Self::write). Earlier parts of `buf` may already
    /// have been accepted by then.
    pub async fn write_all(&mut self, mut buf: &[u8]) -> Result<(), WriteError> {
        while !buf.is_empty() {
            let written = self.write(buf).await?;
            buf = &buf[written..];
        }
        Ok(())
    }

    /// Writes as much of `bufs` as the stream currently accepts and returns a [`Written`]
    /// summary of the progress.
    ///
    /// NOTE(review): a chunk that is only partly written is presumably advanced in place by
    /// the protocol layer. `write_all_chunks` relies on this because it only skips the
    /// `chunks` that were fully consumed. Confirm against `proto::SendStream::write_chunks`.
    ///
    /// # Errors
    ///
    /// Same as [`write`](Self::write).
    pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
        poll_fn(|cx| self.execute_poll(cx, |s| s.write_chunks(bufs))).await
    }

    /// Writes a single chunk in full.
    ///
    /// # Errors
    ///
    /// Same as [`write_all_chunks`](Self::write_all_chunks).
    pub async fn write_chunk(&mut self, buf: Bytes) -> Result<(), WriteError> {
        self.write_all_chunks(&mut [buf]).await
    }

    /// Writes every chunk in `bufs`, calling [`write_chunks`](Self::write_chunks) until all
    /// chunks have been consumed.
    ///
    /// # Errors
    ///
    /// Returns the first error from [`write_chunks`](Self::write_chunks).
    pub async fn write_all_chunks(&mut self, mut bufs: &mut [Bytes]) -> Result<(), WriteError> {
        while !bufs.is_empty() {
            let written = self.write_chunks(bufs).await?;
            // Skip the chunks that were fully consumed; any partially written chunk stays in
            // the slice.
            bufs = &mut bufs[written.chunks..];
        }
        Ok(())
    }

    /// Poll routine shared by every write method.
    ///
    /// It takes the connection lock and fails early if 0-RTT was rejected (0-RTT streams only)
    /// or the connection has recorded an error. It then runs `write_fn` against the
    /// protocol-level stream:
    /// - A `Blocked` result parks `cx`'s waker in `blocked_writers` and returns
    ///   `Poll::Pending`.
    /// - Other protocol errors are mapped onto [`WriteError`].
    /// - After a successful write the connection is woken, presumably so its driver picks up
    ///   the new data.
    fn execute_poll<F, R>(&mut self, cx: &mut Context, write_fn: F) -> Poll<Result<R, WriteError>>
    where
        F: FnOnce(&mut proto::SendStream) -> Result<R, proto::WriteError>,
    {
        use proto::WriteError::*;
        let mut conn = self.conn.state.lock("SendStream::poll_write");
        if self.is_0rtt {
            conn.check_0rtt()
                .map_err(|()| WriteError::ZeroRttRejected)?;
        }
        if let Some(ref x) = conn.error {
            return Poll::Ready(Err(WriteError::ConnectionLost(x.clone())));
        }
        let result = match write_fn(&mut conn.inner.send_stream(self.stream)) {
            Ok(result) => result,
            Err(Blocked) => {
                // The waker is stored under the lock, so the connection can wake this task
                // once the stream can accept data again.
                conn.blocked_writers.insert(self.stream, cx.waker().clone());
                return Poll::Pending;
            }
            Err(Stopped(error_code)) => {
                return Poll::Ready(Err(WriteError::Stopped(error_code)));
            }
            Err(ClosedStream) => {
                return Poll::Ready(Err(WriteError::ClosedStream));
            }
        };
        conn.wake();
        Poll::Ready(Ok(result))
    }

    /// Signals that no more data will be written to the stream.
    ///
    /// On success the connection is woken.
    ///
    /// # Errors
    ///
    /// Returns [`ClosedStream`] if the protocol layer reports the stream as already closed.
    /// A stream the peer has already stopped is deliberately *not* an error here; use
    /// [`stopped`](Self::stopped) to observe that.
    pub fn finish(&mut self) -> Result<(), ClosedStream> {
        // Lock label follows the `SendStream::*` convention used by every other lock site.
        let mut conn = self.conn.state.lock("SendStream::finish");
        match conn.inner.send_stream(self.stream).finish() {
            Ok(()) => {
                conn.wake();
                Ok(())
            }
            Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
            Err(FinishError::Stopped(_)) => Ok(()),
        }
    }

    /// Resets the send side of the stream with `error_code`, then wakes the connection.
    ///
    /// For a 0-RTT stream whose 0-RTT data was rejected this does nothing and returns
    /// `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns [`ClosedStream`] if the protocol layer reports the stream as already closed.
    pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        let mut conn = self.conn.state.lock("SendStream::reset");
        if self.is_0rtt && conn.check_0rtt().is_err() {
            return Ok(());
        }
        conn.inner.send_stream(self.stream).reset(error_code)?;
        conn.wake();
        Ok(())
    }

    /// Sets the priority the protocol layer records for this stream.
    ///
    /// # Errors
    ///
    /// Returns [`ClosedStream`] if the protocol layer reports the stream as closed.
    pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
        let mut conn = self.conn.state.lock("SendStream::set_priority");
        conn.inner.send_stream(self.stream).set_priority(priority)?;
        Ok(())
    }

    /// Returns the priority the protocol layer records for this stream.
    ///
    /// # Errors
    ///
    /// Returns [`ClosedStream`] if the protocol layer reports the stream as closed.
    pub fn priority(&self) -> Result<i32, ClosedStream> {
        let mut conn = self.conn.state.lock("SendStream::priority");
        conn.inner.send_stream(self.stream).priority()
    }

    /// Returns a future that completes once the stream is stopped or closed.
    ///
    /// The future resolves to:
    /// - `Ok(Some(code))` when the peer stopped the stream with `code`;
    /// - `Ok(None)` when the protocol layer reports the stream as closed;
    /// - `Err(StoppedError::ZeroRttRejected)` for a 0-RTT stream whose 0-RTT data was rejected;
    /// - `Err(StoppedError::ConnectionLost(_))` when the connection has failed.
    ///
    /// The future owns a clone of the connection handle, so it is `'static` and may outlive
    /// `self`.
    pub fn stopped(
        &self,
    ) -> impl Future<Output = Result<Option<VarInt>, StoppedError>> + Send + Sync + 'static {
        let conn = self.conn.clone();
        let stream = self.stream;
        let is_0rtt = self.is_0rtt;
        async move {
            loop {
                // `notify` is declared outside the inner block because the future returned by
                // `notified()` borrows it and is awaited after the block ends.
                let notify;
                {
                    // The lock guard lives only inside this block, so it is dropped before the
                    // `.await` below. Holding it across the await would break the `Send` bound.
                    let mut conn = conn.state.lock("SendStream::stopped");
                    if let Some(output) = send_stream_stopped(&mut conn, stream, is_0rtt) {
                        return output;
                    }
                    notify = conn.stopped.entry(stream).or_default().clone();
                    // Create the wait future while still holding the lock. Assuming
                    // `conn.stopped` holds tokio `Notify` handles (TODO confirm), a `Notified`
                    // future observes notifications from the moment it is created, so one
                    // sent right after the lock is released is not missed.
                    notify.notified()
                }
                .await
            }
        }
    }

    /// Returns the identifier of this stream.
    pub fn id(&self) -> StreamId {
        self.stream
    }

    /// Poll-based form of [`write`](Self::write), used by the `AsyncWrite` impls.
    ///
    /// Each call builds a fresh `write` future and polls it once. This is sound because
    /// `write` keeps no state between polls: it is a `poll_fn` over `execute_poll`.
    pub fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &[u8],
    ) -> Poll<Result<usize, WriteError>> {
        pin!(self.get_mut().write(buf)).as_mut().poll(cx)
    }
}
/// Non-blocking check behind [`SendStream::stopped`].
///
/// Returns `Some(outcome)` once the stopped-future can resolve, or `None` when the caller
/// should keep waiting for a notification and then check again.
fn send_stream_stopped(
    conn: &mut State,
    stream: StreamId,
    is_0rtt: bool,
) -> Option<Result<Option<VarInt>, StoppedError>> {
    // A rejected 0-RTT stream is reported before anything the stream itself says.
    let zero_rtt_rejected = is_0rtt && conn.check_0rtt().is_err();
    if zero_rtt_rejected {
        return Some(Err(StoppedError::ZeroRttRejected));
    }

    let status = conn.inner.send_stream(stream).stopped();
    match status {
        // The peer stopped the stream: surface its error code.
        Ok(Some(code)) => Some(Ok(Some(code))),
        // The protocol layer considers the stream closed: done, with no stop code.
        Err(_) => Some(Ok(None)),
        // Still open and not stopped: only finished if the connection itself has failed.
        Ok(None) => match conn.error.as_ref() {
            Some(err) => Some(Err(StoppedError::ConnectionLost(err.clone()))),
            None => None,
        },
    }
}
#[cfg(feature = "futures-io")]
impl futures_io::AsyncWrite for SendStream {
    /// Delegates to the inherent [`SendStream::poll_write`] and converts [`WriteError`] into
    /// [`io::Error`].
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
        SendStream::poll_write(self, cx, buf).map_err(io::Error::from)
    }

    /// Nothing to flush: `SendStream` holds no buffer of its own, because each write is handed
    /// to the protocol-level stream during the write call.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream and completes immediately with the result of
    /// [`SendStream::finish`].
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
        let outcome = self.get_mut().finish();
        Poll::Ready(outcome.map_err(Into::into))
    }
}
impl tokio::io::AsyncWrite for SendStream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.poll_write(cx, buf).map_err(Into::into)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context) -> Poll<io::Result<()>> {
Poll::Ready(self.get_mut().finish().map_err(Into::into))
}
}
impl Drop for SendStream {
    /// Releases the stream when the handle goes away.
    ///
    /// Any waker parked in `blocked_writers` for this stream is discarded. Then, unless the
    /// connection has failed or (for a 0-RTT stream) 0-RTT was rejected, the stream is
    /// finished. If the peer had already stopped it, it is reset with the peer's error code
    /// instead. The connection is woken whenever the finish or reset succeeded.
    fn drop(&mut self) {
        let mut state = self.conn.state.lock("SendStream::drop");
        // This handle can no longer be polled, so a waker parked for it is useless.
        state.blocked_writers.remove(&self.stream);

        if state.error.is_some() {
            return;
        }
        if self.is_0rtt && state.check_0rtt().is_err() {
            return;
        }

        let id = self.stream;
        let should_wake = match state.inner.send_stream(id).finish() {
            Ok(()) => true,
            // Already stopped by the peer: reset instead, echoing the peer's code.
            Err(FinishError::Stopped(code)) => state.inner.send_stream(id).reset(code).is_ok(),
            Err(FinishError::ClosedStream) => false,
        };
        if should_wake {
            state.wake();
        }
    }
}
/// Errors reported when writing to a [`SendStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
    /// The peer stopped the stream; carries the error code reported by the protocol layer.
    #[error("sending stopped by peer: error {0}")]
    Stopped(VarInt),
    /// The connection has failed, so the stream can no longer be written.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The protocol layer reports the stream as closed.
    #[error("closed stream")]
    ClosedStream,
    /// This is a 0-RTT stream and the 0-RTT data was rejected.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<ClosedStream> for WriteError {
#[inline]
fn from(_: ClosedStream) -> Self {
Self::ClosedStream
}
}
impl From<StoppedError> for WriteError {
fn from(x: StoppedError) -> Self {
match x {
StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
}
}
}
impl From<WriteError> for io::Error {
fn from(x: WriteError) -> Self {
use WriteError::*;
let kind = match x {
Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
};
Self::new(kind, x)
}
}
/// Errors reported by the future returned from [`SendStream::stopped`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum StoppedError {
    /// The connection failed while the stream was still open and not stopped.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// This is a 0-RTT stream and the 0-RTT data was rejected.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<StoppedError> for io::Error {
fn from(x: StoppedError) -> Self {
use StoppedError::*;
let kind = match x {
ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) => io::ErrorKind::NotConnected,
};
Self::new(kind, x)
}
}