use std::{
io,
mem::MaybeUninit,
task::{Context, Poll},
};
use compio::buf::{BufResult, IoBufMut, bytes::Bytes};
use compio::io::AsyncRead;
use futures_util::future::poll_fn;
use noq_proto::{Chunk, Chunks, ClosedStream, ReadableError, StreamId, VarInt};
use thiserror::Error;
use crate::{ConnectionError, ConnectionInner, sync::shared::Shared};
/// A handle for receiving data on one stream of a connection.
///
/// Dropping the handle before all data has been read stops the stream with error code 0.
#[derive(Debug)]
pub struct RecvStream {
    /// Connection whose shared state this stream reads from.
    conn: Shared<ConnectionInner>,
    /// Identifier of this stream within the connection.
    stream: StreamId,
    /// Whether the stream was opened with 0-RTT data; if that data is rejected, reads fail
    /// with `ReadError::ZeroRttRejected`.
    is_0rtt: bool,
    /// Set once no further data will be read (stream finished, reset observed with no data
    /// pending, or `stop` called); reads then report end of stream and drop skips the stop.
    all_data_read: bool,
    /// Reset error code observed by a read, remembered so later calls can report it.
    reset: Option<VarInt>,
}
impl RecvStream {
    /// Creates a receive handle for `stream` on the connection `conn`.
    ///
    /// `is_0rtt` marks a stream opened with 0-RTT data; operations on such a stream first
    /// check whether the connection's 0-RTT data was accepted.
    pub(crate) fn new(conn: Shared<ConnectionInner>, stream: StreamId, is_0rtt: bool) -> Self {
        Self {
            conn,
            stream,
            is_0rtt,
            all_data_read: false,
            reset: None,
        }
    }

    /// Returns the identifier of this stream.
    pub fn id(&self) -> StreamId {
        self.stream
    }

    /// Returns `true` if this stream was opened with 0-RTT data.
    pub fn is_0rtt(&self) -> bool {
        self.is_0rtt
    }

    /// Asks the peer to stop sending on this stream, using `error_code`.
    ///
    /// On success the stream is marked as fully read. Later reads report end of stream, and
    /// dropping the handle does not issue a second stop. If this is a 0-RTT stream whose
    /// 0-RTT data was rejected, nothing is done and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Propagates the [`ClosedStream`] error returned by the protocol layer's `stop`.
    pub fn stop(&mut self, error_code: VarInt) -> Result<(), ClosedStream> {
        let mut state = self.conn.state();
        if self.is_0rtt && !state.check_0rtt() {
            return Ok(());
        }
        state.conn.recv_stream(self.stream).stop(error_code)?;
        // Presumably wakes the connection driver so the stop request gets transmitted.
        state.wake();
        self.all_data_read = true;
        Ok(())
    }

    /// Waits until the peer resets this stream.
    ///
    /// Returns `Ok(Some(code))` with the reset error code. This happens immediately if an
    /// earlier read already observed the reset.
    ///
    /// Returns `Ok(None)` when the protocol layer's `received_reset` reports an error,
    /// presumably because the stream has already finished or closed.
    ///
    /// Otherwise the task's waker is registered in the same table used by blocked readers,
    /// and the future stays pending.
    ///
    /// # Errors
    ///
    /// - [`ResetError::ZeroRttRejected`] for a 0-RTT stream whose 0-RTT data was rejected.
    /// - [`ResetError::ConnectionLost`] if the connection has failed while no reset is
    ///   pending.
    pub async fn received_reset(&mut self) -> Result<Option<VarInt>, ResetError> {
        poll_fn(|cx| {
            let mut state = self.conn.state();
            if self.is_0rtt && !state.check_0rtt() {
                return Poll::Ready(Err(ResetError::ZeroRttRejected));
            }
            if let Some(code) = self.reset {
                return Poll::Ready(Ok(Some(code)));
            }
            match state.conn.recv_stream(self.stream).received_reset() {
                Err(_) => Poll::Ready(Ok(None)),
                Ok(Some(error_code)) => {
                    state.wake();
                    Poll::Ready(Ok(Some(error_code)))
                }
                Ok(None) => {
                    if let Some(e) = &state.error {
                        return Poll::Ready(Err(e.clone().into()));
                    }
                    state.readable.insert(self.stream, cx.waker().clone());
                    Poll::Pending
                }
            }
        })
        .await
    }

    /// Shared driver for every read operation on this stream.
    ///
    /// Runs `read_fn` over the protocol layer's [`Chunks`] for this stream, in ordered or
    /// unordered mode as selected by `ordered`. The resulting [`ReadStatus`] becomes a
    /// poll result:
    ///
    /// - `Ok(Some(_))`: data was read.
    /// - `Ok(read)`: the stream finished, with any final data. Every later call returns
    ///   `Ok(None)`.
    /// - `Err(ReadError::Reset(_))`: a reset was observed with no data pending.
    /// - `Pending`: no data is available yet. The waker is registered first.
    ///
    /// A reset reported together with data is stored in `self.reset`. The data is returned
    /// now and the reset on the next call.
    ///
    /// # Errors
    ///
    /// - [`ReadError::ZeroRttRejected`] for a rejected 0-RTT stream.
    /// - [`ReadError::ConnectionLost`] when blocked on a failed connection.
    /// - [`ReadError::ClosedStream`] / [`ReadError::IllegalOrderedRead`], forwarded from
    ///   the protocol layer's `read`.
    fn execute_poll_read<F, T>(
        &mut self,
        cx: &mut Context,
        ordered: bool,
        mut read_fn: F,
    ) -> Poll<Result<Option<T>, ReadError>>
    where
        F: FnMut(&mut Chunks) -> ReadStatus<T>,
    {
        use noq_proto::ReadError::*;
        if self.all_data_read {
            return Poll::Ready(Ok(None));
        }
        let mut state = self.conn.state();
        if self.is_0rtt && !state.check_0rtt() {
            return Poll::Ready(Err(ReadError::ZeroRttRejected));
        }
        // A reset stored by an earlier call (one that also returned data) is reported now,
        // without reading from the stream again.
        let status = match self.reset {
            Some(code) => ReadStatus::Failed(None, Reset(code)),
            None => {
                let mut recv = state.conn.recv_stream(self.stream);
                let mut chunks = recv.read(ordered)?;
                let status = read_fn(&mut chunks);
                if chunks.finalize().should_transmit() {
                    state.wake();
                }
                status
            }
        };
        match status {
            ReadStatus::Readable(read) => Poll::Ready(Ok(Some(read))),
            ReadStatus::Finished(read) => {
                self.all_data_read = true;
                Poll::Ready(Ok(read))
            }
            ReadStatus::Failed(read, Blocked) => match read {
                // Some data was read before blocking; hand it out now.
                Some(val) => Poll::Ready(Ok(Some(val))),
                None => {
                    if let Some(error) = &state.error {
                        return Poll::Ready(Err(error.clone().into()));
                    }
                    state.readable.insert(self.stream, cx.waker().clone());
                    Poll::Pending
                }
            },
            ReadStatus::Failed(read, Reset(error_code)) => match read {
                None => {
                    self.all_data_read = true;
                    self.reset = Some(error_code);
                    Poll::Ready(Err(ReadError::Reset(error_code)))
                }
                done => {
                    // Deliver the data first; the stored reset is reported on the next call.
                    self.reset = Some(error_code);
                    Poll::Ready(Ok(done))
                }
            },
        }
    }

    /// Polls for ordered data and copies it into the front of `buf`.
    ///
    /// Returns `Ok(Some(n))` with the number of bytes initialized at the start of `buf`.
    /// Returns `Ok(Some(0))` immediately when `buf` is empty, and `Ok(None)` at end of
    /// stream.
    pub(crate) fn poll_read_impl(
        &mut self,
        cx: &mut Context,
        buf: &mut [MaybeUninit<u8>],
    ) -> Poll<Result<Option<usize>, ReadError>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(Some(0)));
        }
        self.execute_poll_read(cx, true, |chunks| {
            let mut read = 0;
            loop {
                if read >= buf.len() {
                    return ReadStatus::Readable(read);
                }
                match chunks.next(buf.len() - read) {
                    Ok(Some(chunk)) => {
                        let bytes = chunk.bytes;
                        let len = bytes.len();
                        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and the
                        // view covers exactly the `len` initialized bytes owned by `bytes`,
                        // which stays alive for the duration of the copy.
                        buf[read..read + len].copy_from_slice(unsafe {
                            std::slice::from_raw_parts(bytes.as_ptr().cast(), len)
                        });
                        read += len;
                    }
                    res => {
                        return (if read == 0 { None } else { Some(read) }, res.err()).into();
                    }
                }
            }
        })
    }

    /// Like `poll_read_impl`, but reports end of stream as `Ok(0)` instead of `Ok(None)`.
    pub fn poll_read_uninit(
        &mut self,
        cx: &mut Context,
        buf: &mut [MaybeUninit<u8>],
    ) -> Poll<Result<usize, ReadError>> {
        self.poll_read_impl(cx, buf)
            .map(|res| res.map(|n| n.unwrap_or_default()))
    }

    /// Reads the next chunk of at most `max_length` bytes.
    ///
    /// When `ordered` is `false`, chunks may arrive out of order; each [`Chunk`] carries
    /// its stream offset. Returns `Ok(None)` at end of stream.
    pub async fn read_chunk(
        &mut self,
        max_length: usize,
        ordered: bool,
    ) -> Result<Option<Chunk>, ReadError> {
        poll_fn(|cx| {
            self.execute_poll_read(cx, ordered, |chunks| match chunks.next(max_length) {
                Ok(Some(chunk)) => ReadStatus::Readable(chunk),
                res => (None, res.err()).into(),
            })
        })
        .await
    }

    /// Fills `bufs` from the front with ordered chunks and returns how many were written.
    ///
    /// Returns `Ok(Some(0))` immediately when `bufs` is empty, and `Ok(None)` at end of
    /// stream.
    pub async fn read_chunks(&mut self, bufs: &mut [Bytes]) -> Result<Option<usize>, ReadError> {
        if bufs.is_empty() {
            return Ok(Some(0));
        }
        poll_fn(|cx| {
            self.execute_poll_read(cx, true, |chunks| {
                let mut read = 0;
                loop {
                    if read >= bufs.len() {
                        return ReadStatus::Readable(read);
                    }
                    match chunks.next(usize::MAX) {
                        Ok(Some(chunk)) => {
                            bufs[read] = chunk.bytes;
                            read += 1;
                        }
                        res => {
                            return (if read == 0 { None } else { Some(read) }, res.err()).into();
                        }
                    }
                }
            })
        })
        .await
    }

    /// Reads all remaining stream data into `buf` and returns the number of bytes written.
    ///
    /// Data is collected with unordered reads and reassembled by stream offset. The lowest
    /// offset received lands at the start of `buf.as_uninit()`. If `buf` is too small, it
    /// is grown through `IoBufMut::reserve`. Returns `Ok(0)` without touching `buf` when no
    /// data was received.
    ///
    /// # Errors
    ///
    /// - Any [`ReadError`] is converted to [`io::Error`].
    /// - [`io::ErrorKind::OutOfMemory`] is returned if the data length does not fit in
    ///   `usize`, if `reserve` fails, or if `buf` is still too small afterwards.
    pub async fn read_to_end<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
        let mut start = u64::MAX;
        let mut end = 0;
        let mut chunks = vec![];
        loop {
            let chunk = match self.read_chunk(usize::MAX, false).await {
                Ok(Some(chunk)) => chunk,
                Ok(None) => break,
                Err(e) => return BufResult(Err(e.into()), buf),
            };
            start = start.min(chunk.offset);
            end = end.max(chunk.offset + chunk.bytes.len() as u64);
            chunks.push((chunk.offset, chunk.bytes));
        }
        if start == u64::MAX || start >= end {
            return BufResult(Ok(0), buf);
        }
        // On 32-bit targets the received span can exceed the address space; a plain `as`
        // cast would silently truncate it.
        let Ok(len) = usize::try_from(end - start) else {
            return BufResult(
                Err(io::Error::new(
                    io::ErrorKind::OutOfMemory,
                    "stream data does not fit in memory",
                )),
                buf,
            );
        };
        if buf.as_uninit().len() < len {
            // Request the full `len`, not `len - capacity`. If `reserve` counts the extra
            // bytes from the buffer's current length (as `Vec::reserve` does), asking only
            // for the missing capacity can leave the buffer short of `len`.
            if let Err(e) = buf.reserve(len) {
                return BufResult(Err(io::Error::new(io::ErrorKind::OutOfMemory, e)), buf);
            }
            // Never index past the buffer, even if `reserve` succeeded without growing
            // enough.
            if buf.as_uninit().len() < len {
                return BufResult(
                    Err(io::Error::new(
                        io::ErrorKind::OutOfMemory,
                        "buffer too small for stream data",
                    )),
                    buf,
                );
            }
        }
        let slice = &mut buf.as_uninit()[..len];
        // Zero first so every byte reported through `advance_to` is initialized, even if
        // the received chunks left a gap.
        slice.fill(MaybeUninit::new(0));
        for (offset, bytes) in chunks {
            let offset = (offset - start) as usize;
            let buf_len = bytes.len();
            // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`, and the view covers
            // exactly the `buf_len` initialized bytes owned by `bytes`.
            slice[offset..offset + buf_len].copy_from_slice(unsafe {
                std::slice::from_raw_parts(bytes.as_ptr().cast(), buf_len)
            });
        }
        // SAFETY: the first `len` bytes of `buf.as_uninit()` were initialized above.
        unsafe { buf.advance_to(len) }
        BufResult(Ok(len), buf)
    }

    /// Wraps this stream in a `CompatRecvStream` adapter.
    #[cfg(feature = "io-compat")]
    pub fn into_compat(self) -> CompatRecvStream {
        CompatRecvStream(self)
    }
}
impl Drop for RecvStream {
    /// Removes this stream's registered waker.
    ///
    /// If data is still unread, this also stops the stream with error code 0. The stop is
    /// skipped when the connection has already failed or when this 0-RTT stream's data was
    /// rejected.
    fn drop(&mut self) {
        let mut state = self.conn.state();
        state.readable.remove(&self.stream);

        // Same evaluation order as before: the 0-RTT check only runs while the connection
        // is healthy.
        let nothing_to_stop = state.error.is_some()
            || (self.is_0rtt && !state.check_0rtt())
            || self.all_data_read;
        if nothing_to_stop {
            return;
        }

        // A `ClosedStream` error only means there is nothing left to stop; ignore it.
        let _ = state.conn.recv_stream(self.stream).stop(0u32.into());
        state.wake();
    }
}
/// Outcome of one pass of a read closure run by `RecvStream::execute_poll_read`.
enum ReadStatus<T> {
    /// Data was read; the stream stays open for further reads.
    Readable(T),
    /// The stream has no more data; carries any data read in this pass.
    Finished(Option<T>),
    /// Reading stopped with a protocol read error (blocked or reset); carries any data read
    /// before the error.
    Failed(Option<T>, noq_proto::ReadError),
}
impl<T> From<(Option<T>, Option<noq_proto::ReadError>)> for ReadStatus<T> {
fn from(status: (Option<T>, Option<noq_proto::ReadError>)) -> Self {
match status {
(read, None) => Self::Finished(read),
(read, Some(e)) => Self::Failed(read, e),
}
}
}
/// Errors reported by reads on a [`RecvStream`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadError {
    /// The peer reset the stream with the contained error code.
    #[error("stream reset by peer: error {0}")]
    Reset(VarInt),
    /// The connection failed; carries the connection's error.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The protocol layer reported that the stream is closed.
    #[error("closed stream")]
    ClosedStream,
    /// An ordered read was attempted after an unordered read on the same stream.
    #[error("ordered read after unordered read")]
    IllegalOrderedRead,
    /// The stream was opened with 0-RTT data and that 0-RTT data was rejected.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<ReadableError> for ReadError {
fn from(e: ReadableError) -> Self {
match e {
ReadableError::ClosedStream => Self::ClosedStream,
ReadableError::IllegalOrderedRead => Self::IllegalOrderedRead,
}
}
}
impl From<ResetError> for ReadError {
fn from(e: ResetError) -> Self {
match e {
ResetError::ConnectionLost(e) => Self::ConnectionLost(e),
ResetError::ZeroRttRejected => Self::ZeroRttRejected,
}
}
}
impl From<ReadError> for io::Error {
fn from(x: ReadError) -> Self {
use self::ReadError::*;
let kind = match x {
Reset { .. } | ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) | ClosedStream => io::ErrorKind::NotConnected,
IllegalOrderedRead => io::ErrorKind::InvalidInput,
};
Self::new(kind, x)
}
}
/// Errors from reading an exact amount of data (returned by `CompatRecvStream::read_exact`).
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ReadExactError {
    /// The stream ended before the buffer was filled; carries the number of bytes still
    /// missing.
    #[error("stream finished early (expected {0} bytes more)")]
    FinishedEarly(usize),
    /// An underlying read failed.
    #[error(transparent)]
    ReadError(#[from] ReadError),
}
/// Errors reported by [`RecvStream::received_reset`].
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum ResetError {
    /// The connection failed before a reset was received.
    #[error("connection lost")]
    ConnectionLost(#[from] ConnectionError),
    /// The stream was opened with 0-RTT data and that 0-RTT data was rejected.
    #[error("0-RTT rejected")]
    ZeroRttRejected,
}
impl From<ResetError> for io::Error {
fn from(x: ResetError) -> Self {
use ResetError::*;
let kind = match x {
ZeroRttRejected => io::ErrorKind::ConnectionReset,
ConnectionLost(_) => io::ErrorKind::NotConnected,
};
Self::new(kind, x)
}
}
impl AsyncRead for RecvStream {
async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
let res = poll_fn(|cx| self.poll_read_uninit(cx, buf.as_uninit()))
.await
.inspect(|&n| unsafe { buf.advance_to(n) })
.map_err(Into::into);
BufResult(res, buf)
}
}
#[cfg(feature = "io-compat")]
mod compat {
    use std::{
        ops::{Deref, DerefMut},
        pin::Pin,
        task::ready,
    };

    use compio::buf::{IntoInner, bytes::BufMut};

    use super::*;

    /// Adapter over [`RecvStream`] that reads into `bytes::BufMut` buffers and implements
    /// `futures_util::AsyncRead`.
    ///
    /// It dereferences to the wrapped [`RecvStream`]; `into_inner` unwraps it.
    pub struct CompatRecvStream(pub(super) RecvStream);

    impl CompatRecvStream {
        /// Polls for ordered data and writes it into the next writable chunk of `buf`
        /// (`BufMut::chunk_mut`), then advances `buf` past the bytes written.
        ///
        /// Returns `Ok(Some(n))` with the number of bytes written, and `Ok(None)` at end of
        /// stream.
        fn poll_read(
            &mut self,
            cx: &mut Context,
            mut buf: impl BufMut,
        ) -> Poll<Result<Option<usize>, ReadError>> {
            // SAFETY: `as_uninit_slice_mut` requires that the slice is neither read nor has
            // uninitialized bytes written into it. `poll_read_impl` only copies initialized
            // stream data into it and never reads it.
            self.poll_read_impl(cx, unsafe { buf.chunk_mut().as_uninit_slice_mut() })
                .map(|res| {
                    if let Ok(Some(n)) = &res {
                        // SAFETY: `poll_read_impl` initialized the first `n` bytes of the
                        // chunk returned by `chunk_mut`.
                        unsafe { buf.advance_mut(*n) }
                    }
                    res
                })
        }

        /// Reads available data into `buf`.
        ///
        /// Returns the number of bytes written, or `Ok(None)` at end of stream.
        pub async fn read(&mut self, mut buf: impl BufMut) -> Result<Option<usize>, ReadError> {
            poll_fn(|cx| self.poll_read(cx, &mut buf)).await
        }

        /// Reads until `buf` has no remaining capacity.
        ///
        /// # Errors
        ///
        /// - [`ReadExactError::FinishedEarly`] with `buf.remaining_mut()` if the stream ends
        ///   first.
        /// - The underlying [`ReadError`] otherwise.
        ///
        /// NOTE(review): in `bytes`, growable buffers such as `Vec<u8>` report a very large
        /// `remaining_mut()`. With them, this reads to end of stream and then fails with
        /// `FinishedEarly`. A bounded buffer (e.g. `&mut [u8]` or `BufMut::limit`) is
        /// presumably intended.
        pub async fn read_exact(&mut self, mut buf: impl BufMut) -> Result<(), ReadExactError> {
            poll_fn(|cx| {
                // Progress is recorded in `buf` itself. Returning `Pending` part-way through
                // the loop loses nothing; the next poll resumes filling it.
                while buf.has_remaining_mut() {
                    if ready!(self.poll_read(cx, &mut buf))?.is_none() {
                        return Poll::Ready(Err(ReadExactError::FinishedEarly(
                            buf.remaining_mut(),
                        )));
                    }
                }
                Poll::Ready(Ok(()))
            })
            .await
        }
    }

    impl IntoInner for CompatRecvStream {
        type Inner = RecvStream;

        /// Unwraps the adapter, returning the underlying [`RecvStream`].
        fn into_inner(self) -> Self::Inner {
            self.0
        }
    }

    impl Deref for CompatRecvStream {
        type Target = RecvStream;

        fn deref(&self) -> &Self::Target {
            &self.0
        }
    }

    impl DerefMut for CompatRecvStream {
        fn deref_mut(&mut self) -> &mut Self::Target {
            &mut self.0
        }
    }

    impl futures_util::AsyncRead for CompatRecvStream {
        /// Reads ordered data into `buf`; `Ok(0)` signals end of stream.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`. Viewing initialized
            // bytes as possibly-uninitialized is sound here because `poll_read_uninit` only
            // writes initialized stream data into the slice.
            self.get_mut()
                .poll_read_uninit(cx, unsafe {
                    std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len())
                })
                .map_err(Into::into)
        }
    }
}
#[cfg(feature = "io-compat")]
pub use compat::CompatRecvStream;
#[cfg(feature = "h3")]
pub(crate) mod h3_impl {
    use h3::quic::{self, StreamErrorIncoming};

    use super::*;

    /// Translates stream read failures into h3's incoming-stream error type.
    impl From<ReadError> for StreamErrorIncoming {
        fn from(e: ReadError) -> Self {
            match e {
                ReadError::Reset(code) => Self::StreamTerminated {
                    error_code: code.into_inner(),
                },
                ReadError::ConnectionLost(conn_err) => Self::ConnectionErrorIncoming {
                    connection_error: conn_err.into(),
                },
                // `poll_data` below only issues ordered reads.
                ReadError::IllegalOrderedRead => unreachable!("illegal ordered read"),
                other => Self::Unknown(Box::new(other)),
            }
        }
    }

    impl quic::RecvStream for RecvStream {
        type Buf = Bytes;

        /// Yields the next ordered chunk of stream data, or `None` at end of stream.
        fn poll_data(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Result<Option<Self::Buf>, StreamErrorIncoming>> {
            let polled = self.execute_poll_read(cx, true, |chunks| {
                match chunks.next(usize::MAX) {
                    Ok(Some(chunk)) => ReadStatus::Readable(chunk.bytes),
                    res => (None, res.err()).into(),
                }
            });
            polled.map_err(StreamErrorIncoming::from)
        }

        /// Stops the stream with `error_code`.
        ///
        /// # Panics
        ///
        /// Panics if `error_code` does not fit in a `VarInt`.
        fn stop_sending(&mut self, error_code: u64) {
            let code = VarInt::try_from(error_code).expect("invalid error_code");
            // The h3 trait gives no way to report a failure here, so the result is dropped.
            let _ = self.stop(code);
        }

        /// Returns this stream's identifier in h3's representation.
        fn recv_id(&self) -> quic::StreamId {
            let raw: u64 = self.stream.into();
            raw.try_into().unwrap()
        }
    }
}