use std::{
io,
task::{Context, Poll},
};
use compio::buf::{BufResult, IoBuf, bytes::Bytes};
use compio::io::AsyncWrite;
use futures_util::{future::poll_fn, ready};
use noq_proto::{ClosedStream, FinishError, StreamId, VarInt, Written};
use thiserror::Error;
use crate::{ConnectionError, ConnectionInner, sync::shared::Shared};
/// Handle to the sending side of a single stream on a connection.
///
/// All operations go through the shared connection state held in `conn` and
/// address the stream identified by `stream`.
#[derive(Debug)]
pub struct SendStream {
/// Shared connection state used by every operation on this stream.
conn: Shared<ConnectionInner>,
/// Identifier of this stream within the connection.
stream: StreamId,
/// When set, several operations first consult `check_0rtt()` on the
/// connection state and bail out if it reports `false`.
is_0rtt: bool,
}
impl SendStream {
/// Creates a send-stream handle for `stream` on the connection `conn`.
///
/// `is_0rtt` is stored unchanged; it later decides whether operations check
/// the connection's `check_0rtt()` before touching the stream.
pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
    Self { conn, stream, is_0rtt }
}
pub fn id(&self) -> StreamId {
self.stream
}
/// Finishes the stream at the protocol layer.
///
/// On success the connection state is notified through `wake()`.
///
/// # Errors
///
/// Returns [`ClosedStream`] when the protocol layer reports the stream as
/// already closed. A stream that was stopped is deliberately reported as
/// `Ok(())`; the stop code is not surfaced here.
pub fn finish(&mut self) -> Result<(), ClosedStream> {
    let mut state = self.conn.state();
    let outcome = state.conn.send_stream(self.stream).finish();
    match outcome {
        Err(FinishError::ClosedStream) => Err(ClosedStream::default()),
        Err(FinishError::Stopped(_)) => Ok(()),
        Ok(()) => {
            state.wake();
            Ok(())
        }
    }
}
/// Resets the stream with the given `error_code`.
///
/// If this is a 0-RTT stream and `check_0rtt()` reports `false`, nothing is
/// done and `Ok(())` is returned. Otherwise the reset is forwarded to the
/// protocol layer and the connection state is notified through `wake()`.
///
/// # Errors
///
/// Propagates [`ClosedStream`] from the protocol-level reset.
pub fn reset(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
    let mut state = self.conn.state();
    let rejected = self.is_0rtt && !state.check_0rtt();
    if !rejected {
        state.conn.send_stream(self.stream).reset(error_code)?;
        state.wake();
    }
    Ok(())
}
/// Sets the priority of this stream by forwarding to the protocol-level
/// stream.
///
/// # Errors
///
/// Returns [`ClosedStream`] as reported by the protocol layer.
pub fn set_priority(&self, priority: i32) -> Result<(), ClosedStream> {
    let mut state = self.conn.state();
    state.conn.send_stream(self.stream).set_priority(priority)
}
/// Returns the current priority of this stream as reported by the
/// protocol-level stream.
///
/// # Errors
///
/// Returns [`ClosedStream`] as reported by the protocol layer.
pub fn priority(&self) -> Result<i32, ClosedStream> {
    let mut state = self.conn.state();
    state.conn.send_stream(self.stream).priority()
}
/// Waits until the stream is stopped or otherwise no longer open.
///
/// Resolves to `Ok(Some(code))` once the protocol layer reports a stop code,
/// and to `Ok(None)` when the protocol layer returns an error for the
/// stopped query. While neither is the case, the task's waker is registered
/// under this stream id and the future stays pending.
///
/// # Errors
///
/// * [`StoppedError::ZeroRttRejected`] if this is a 0-RTT stream and
///   `check_0rtt()` reports `false`.
/// * [`StoppedError::ConnectionLost`] if the connection state carries an
///   error while the stream has not been stopped.
pub async fn stopped(&mut self) -> Result<Option<VarInt>, StoppedError> {
    poll_fn(|cx| {
        let mut state = self.conn.state();
        if self.is_0rtt && !state.check_0rtt() {
            return Poll::Ready(Err(StoppedError::ZeroRttRejected));
        }
        let status = state.conn.send_stream(self.stream).stopped();
        let Ok(code) = status else {
            return Poll::Ready(Ok(None));
        };
        if code.is_some() {
            return Poll::Ready(Ok(code));
        }
        // Not stopped yet: a connection-level error takes precedence over
        // parking the task.
        if let Some(err) = &state.error {
            return Poll::Ready(Err(err.clone().into()));
        }
        state.stopped.insert(self.stream, cx.waker().clone());
        Poll::Pending
    })
    .await
}
/// Runs one write attempt `f` against the protocol-level stream.
///
/// * On success the connection state is notified through `wake()` and the
///   value produced by `f` is returned.
/// * A protocol error that converts into [`WriteError`] is returned as is.
/// * A protocol error that does not convert (the conversion yields
///   `Err(())`) registers the task's waker under this stream id in
///   `writable` and returns `Poll::Pending`.
///
/// # Errors
///
/// Besides converted protocol errors, returns whatever `try_state()` fails
/// with, and [`WriteError::ZeroRttRejected`] if this is a 0-RTT stream and
/// `check_0rtt()` reports `false`.
fn execute_poll_write<F, R>(&mut self, cx: &mut Context, f: F) -> Poll<Result<R, WriteError>>
where
    F: FnOnce(noq_proto::SendStream) -> Result<R, noq_proto::WriteError>,
{
    let mut state = self.conn.try_state()?;
    if self.is_0rtt && !state.check_0rtt() {
        return Poll::Ready(Err(WriteError::ZeroRttRejected));
    }
    let outcome = f(state.conn.send_stream(self.stream));
    let proto_err = match outcome {
        Ok(value) => {
            state.wake();
            return Poll::Ready(Ok(value));
        }
        Err(e) => e,
    };
    let converted: Result<WriteError, ()> = proto_err.try_into();
    match converted {
        Ok(fatal) => Poll::Ready(Err(fatal)),
        Err(()) => {
            state.writable.insert(self.stream, cx.waker().clone());
            Poll::Pending
        }
    }
}
/// Writes chunks from `bufs` to the stream, waiting until the protocol layer
/// accepts a write.
///
/// Returns the [`Written`] report produced by the protocol layer for the
/// attempt that went through.
///
/// # Errors
///
/// Returns a [`WriteError`] if the write attempt fails.
pub async fn write_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Written, WriteError> {
    poll_fn(|cx| {
        self.execute_poll_write(cx, |mut proto_stream| proto_stream.write_chunks(bufs))
    })
    .await
}
/// Writes every chunk in `bufs` to the stream.
///
/// Repeatedly submits the not-yet-completed tail of `bufs`, advancing by the
/// `chunks` count of each [`Written`] report, until all chunks are accounted
/// for. An empty `bufs` completes immediately.
///
/// # Errors
///
/// Returns the first [`WriteError`] hit by a write attempt.
pub async fn write_all_chunks(&mut self, bufs: &mut [Bytes]) -> Result<(), WriteError> {
    let mut done = 0;
    poll_fn(|cx| {
        while done != bufs.len() {
            let progress = ready!(self.execute_poll_write(cx, |mut stream| {
                stream.write_chunks(&mut bufs[done..])
            }))?;
            done += progress.chunks;
        }
        Poll::Ready(Ok(()))
    })
    .await
}
/// Wraps this stream in a [`CompatSendStream`].
///
/// Only available with the `io-compat` feature.
#[cfg(feature = "io-compat")]
pub fn into_compat(self) -> CompatSendStream {
    CompatSendStream(self)
}
}
impl Drop for SendStream {
    /// Cleans up after the handle and finishes the stream on a best-effort
    /// basis.
    ///
    /// Any wakers registered for this stream are removed first. If the
    /// connection already carries an error, or this is a 0-RTT stream and
    /// `check_0rtt()` reports `false`, nothing else happens. Otherwise the
    /// stream is finished; if it was stopped instead, it is reset with the
    /// stop reason. The connection state is notified through `wake()` only
    /// when one of those operations succeeded.
    fn drop(&mut self) {
        let mut state = self.conn.state();
        state.stopped.remove(&self.stream);
        state.writable.remove(&self.stream);

        let skip = state.error.is_some() || (self.is_0rtt && !state.check_0rtt());
        if skip {
            return;
        }

        let finished = state.conn.send_stream(self.stream).finish();
        let should_wake = match finished {
            Ok(()) => true,
            Err(FinishError::Stopped(reason)) => {
                state.conn.send_stream(self.stream).reset(reason).is_ok()
            }
            Err(FinishError::ClosedStream) => false,
        };
        if should_wake {
            state.wake();
        }
    }
}
/// Errors that can occur when writing to a [`SendStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum WriteError {
/// Sending was stopped by the peer; carries the error code reported with
/// the stop.
#[error("sending stopped by peer: error {0}")]
Stopped(VarInt),
/// The connection was lost; wraps the underlying [`ConnectionError`].
#[error("connection lost")]
ConnectionLost(#[from] ConnectionError),
/// The stream is closed.
#[error("closed stream")]
ClosedStream,
/// The stream is a 0-RTT stream and the connection's `check_0rtt()`
/// reported `false`.
#[error("0-RTT rejected")]
ZeroRttRejected,
/// The `h3` adapter was handed new data while previously queued data was
/// still buffered. Only present with the `h3` feature.
#[cfg(feature = "h3")]
#[error("stream not ready")]
NotReady,
}
/// Converts a protocol-level write error into a [`WriteError`].
///
/// `Blocked` has no counterpart here and is reported as `Err(())`, which
/// callers use as the signal to wait and retry.
impl TryFrom<noq_proto::WriteError> for WriteError {
    type Error = ();

    fn try_from(value: noq_proto::WriteError) -> Result<Self, Self::Error> {
        match value {
            noq_proto::WriteError::Blocked => Err(()),
            noq_proto::WriteError::Stopped(code) => Ok(Self::Stopped(code)),
            noq_proto::WriteError::ClosedStream => Ok(Self::ClosedStream),
        }
    }
}
impl From<StoppedError> for WriteError {
fn from(x: StoppedError) -> Self {
match x {
StoppedError::ConnectionLost(e) => Self::ConnectionLost(e),
StoppedError::ZeroRttRejected => Self::ZeroRttRejected,
}
}
}
impl From<WriteError> for io::Error {
fn from(x: WriteError) -> Self {
use WriteError::*;
let kind = match &x {
Stopped(_) | ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
#[cfg(feature = "h3")]
NotReady => io::ErrorKind::Other,
};
Self::new(kind, x)
}
}
/// Errors that can occur while waiting for a [`SendStream`] to be stopped.
#[derive(Debug, thiserror::Error, Clone, PartialEq, Eq)]
pub enum StoppedError {
/// The connection was lost; wraps the underlying [`ConnectionError`].
#[error("connection lost")]
ConnectionLost(#[from] ConnectionError),
/// The stream is a 0-RTT stream and the connection's `check_0rtt()`
/// reported `false`.
#[error("0-RTT rejected")]
ZeroRttRejected,
}
impl From<StoppedError> for io::Error {
fn from(x: StoppedError) -> Self {
use StoppedError::*;
let kind = match x {
ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) => io::ErrorKind::NotConnected,
};
Self::new(kind, x)
}
}
impl AsyncWrite for SendStream {
async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
let res =
poll_fn(|cx| self.execute_poll_write(cx, |mut stream| stream.write(buf.as_init())))
.await
.map_err(Into::into);
BufResult(res, buf)
}
async fn flush(&mut self) -> io::Result<()> {
Ok(())
}
async fn shutdown(&mut self) -> io::Result<()> {
self.finish()?;
Ok(())
}
}
#[cfg(feature = "io-compat")]
mod compat {
use std::{
ops::{Deref, DerefMut},
pin::Pin,
};
use compio::buf::IntoInner;
use super::*;
/// Wrapper around [`SendStream`] that offers slice-based writes and
/// implements `futures_util::AsyncWrite`.
pub struct CompatSendStream(pub(super) SendStream);
impl CompatSendStream {
    /// Writes bytes from `buf`, waiting until the protocol layer accepts a
    /// write, and returns how many bytes were taken.
    ///
    /// # Errors
    ///
    /// Returns a [`WriteError`] if the write attempt fails.
    pub async fn write(&mut self, buf: &[u8]) -> Result<usize, WriteError> {
        let inner = &mut self.0;
        poll_fn(|cx| inner.execute_poll_write(cx, |mut stream| stream.write(buf))).await
    }

    /// Writes all of `buf`, resubmitting the unwritten tail until every byte
    /// has been accepted. An empty `buf` completes immediately.
    ///
    /// # Errors
    ///
    /// Returns the first [`WriteError`] hit by a write attempt.
    pub async fn write_all(&mut self, buf: &[u8]) -> Result<(), WriteError> {
        let inner = &mut self.0;
        let mut sent = 0;
        poll_fn(|cx| {
            while sent != buf.len() {
                let accepted = ready!(
                    inner.execute_poll_write(cx, |mut stream| stream.write(&buf[sent..]))
                )?;
                sent += accepted;
            }
            Poll::Ready(Ok(()))
        })
        .await
    }
}
/// Unwraps the compatibility wrapper, returning the wrapped [`SendStream`].
impl IntoInner for CompatSendStream {
    type Inner = SendStream;

    fn into_inner(self) -> Self::Inner {
        let Self(inner) = self;
        inner
    }
}
impl Deref for CompatSendStream {
type Target = SendStream;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for CompatSendStream {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl futures_util::AsyncWrite for CompatSendStream {
    /// Attempts one write of `buf`, converting a [`WriteError`] into an
    /// [`io::Error`].
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let polled = this.execute_poll_write(cx, |mut stream| stream.write(buf));
        polled.map_err(io::Error::from)
    }

    /// No-op; always ready with `Ok(())`.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Finishes the stream via [`SendStream::finish`] and completes
    /// immediately with its result converted to an [`io::Error`] on failure.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        Poll::Ready(self.get_mut().finish().map_err(io::Error::from))
    }
}
}
#[cfg(feature = "io-compat")]
pub use compat::CompatSendStream;
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
use compio::buf::bytes::Buf;
use h3::quic::{self, StreamErrorIncoming, WriteBuf};
use super::*;
impl From<WriteError> for StreamErrorIncoming {
fn from(e: WriteError) -> Self {
use WriteError::*;
match e {
Stopped(code) => Self::StreamTerminated {
error_code: code.into_inner(),
},
ConnectionLost(e) => Self::ConnectionErrorIncoming {
connection_error: e.into(),
},
e => Self::Unknown(Box::new(e)),
}
}
}
/// Send stream adapter for `h3`: the underlying stream plus at most one
/// buffered write.
pub struct SendStream<B> {
/// Underlying stream that performs the actual writes.
inner: super::SendStream,
/// Data accepted by `send_data` that `poll_ready` has not finished writing;
/// `None` when nothing is pending.
buf: Option<WriteBuf<B>>,
}
impl<B> SendStream<B> {
    /// Creates an `h3` send stream for `stream` on `conn` with no buffered
    /// data. The arguments are forwarded to the underlying stream's
    /// constructor.
    pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
        let inner = super::SendStream::new(conn, stream, is_0rtt);
        Self { inner, buf: None }
    }
}
impl<B> quic::SendStream<B> for SendStream<B>
where
B: Buf,
{
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
if let Some(data) = &mut self.buf {
while data.has_remaining() {
let n = ready!(
self.inner
.execute_poll_write(cx, |mut stream| stream.write(data.chunk()))
)?;
data.advance(n);
}
}
self.buf = None;
Poll::Ready(Ok(()))
}
fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), StreamErrorIncoming> {
if self.buf.is_some() {
return Err(WriteError::NotReady.into());
}
self.buf = Some(data.into());
Ok(())
}
fn poll_finish(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), StreamErrorIncoming>> {
Poll::Ready(
self.inner
.finish()
.map_err(|_| WriteError::ClosedStream.into()),
)
}
fn reset(&mut self, reset_code: u64) {
self.inner
.reset(reset_code.try_into().unwrap_or(VarInt::MAX))
.ok();
}
fn send_id(&self) -> quic::StreamId {
u64::from(self.inner.stream).try_into().unwrap()
}
}
impl<B> quic::SendStreamUnframed<B> for SendStream<B>
where
    B: Buf,
{
    /// Attempts to write the current chunk of `buf` directly to the stream,
    /// bypassing the internal buffer, and advances `buf` by the number of
    /// bytes accepted.
    ///
    /// Must only be called while no data queued by `send_data` is pending,
    /// i.e. after `poll_ready` has completed; otherwise the unframed bytes
    /// would be written ahead of the still-buffered data. Violating this is
    /// a caller bug and trips a debug assertion.
    ///
    /// # Errors
    ///
    /// Returns the converted [`WriteError`] if the write attempt fails.
    fn poll_send<D: Buf>(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut D,
    ) -> Poll<Result<usize, StreamErrorIncoming>> {
        // The stream is "not ready" exactly while `self.buf` still holds
        // queued data, so readiness means the buffer is empty.
        debug_assert!(
            self.buf.is_none(),
            "poll_send called while send stream is not ready"
        );
        let n = ready!(
            self.inner
                .execute_poll_write(cx, |mut stream| stream.write(buf.chunk()))
        )?;
        buf.advance(n);
        Poll::Ready(Ok(n))
    }
}
}