use std::collections::hash_map::Entry;
use std::mem;
use thiserror::Error;
use tracing::debug;
use super::state::get_or_insert_recv;
use super::{ClosedStream, Retransmits, ShouldTransmit, StreamId, StreamsState};
use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead};
use crate::connection::streams::state::StreamRecv;
use crate::{TransportError, VarInt, frame};
/// Receive-side state of a single stream.
///
/// NB: when adding or removing fields, remember to update `Recv::reinit`.
#[derive(Debug, Default)]
pub(super) struct Recv {
    /// Receiving vs. reset by the peer, plus the final size once it is known.
    state: RecvState,
    /// Buffers received data until the application reads it.
    pub(super) assembler: Assembler,
    /// Largest stream-level flow control limit recorded as sent to the peer; data ending
    /// beyond this offset is rejected as a flow control error.
    sent_max_stream_data: u64,
    /// Highest stream offset covered by any accepted STREAM frame (the high-water mark).
    pub(super) end: u64,
    /// Set once the application stopped reading; later data is accounted for but not buffered.
    pub(super) stopped: bool,
}
impl Recv {
pub(super) fn new(initial_max_data: u64) -> Box<Self> {
Box::new(Self {
state: RecvState::default(),
assembler: Assembler::new(),
sent_max_stream_data: initial_max_data,
end: 0,
stopped: false,
})
}
pub(super) fn reinit(&mut self, initial_max_data: u64) {
self.state = RecvState::default();
self.assembler.reinit();
self.sent_max_stream_data = initial_max_data;
self.end = 0;
self.stopped = false;
}
pub(super) fn ingest(
&mut self,
frame: frame::Stream,
payload_len: usize,
received: u64,
max_data: u64,
) -> Result<(u64, bool), TransportError> {
let end = frame.offset + frame.data.len() as u64;
if end >= 2u64.pow(62) {
return Err(TransportError::FLOW_CONTROL_ERROR(
"maximum stream offset too large",
));
}
if let Some(final_offset) = self.final_offset() {
if end > final_offset || (frame.fin && end != final_offset) {
debug!(end, final_offset, "final size error");
return Err(TransportError::FINAL_SIZE_ERROR(""));
}
}
let new_bytes = self.credit_consumed_by(end, received, max_data)?;
if frame.fin && !self.stopped {
if let RecvState::Recv { ref mut size } = self.state {
*size = Some(end);
}
}
self.end = self.end.max(end);
if !self.stopped {
self.assembler.insert(frame.offset, frame.data, payload_len);
}
Ok((new_bytes, frame.fin && self.stopped))
}
pub(super) fn stop(&mut self) -> Result<(u64, ShouldTransmit), ClosedStream> {
if self.stopped {
return Err(ClosedStream { _private: () });
}
self.stopped = true;
self.assembler.clear();
let read_credits = self.end - self.assembler.bytes_read();
Ok((read_credits, ShouldTransmit(self.is_receiving())))
}
pub(super) fn max_stream_data(&mut self, stream_receive_window: u64) -> (u64, ShouldTransmit) {
let max_stream_data = self.assembler.bytes_read() + stream_receive_window;
let diff = max_stream_data - self.sent_max_stream_data;
let transmit = self.can_send_flow_control() && diff >= (stream_receive_window / 8);
(max_stream_data, ShouldTransmit(transmit))
}
pub(super) fn record_sent_max_stream_data(&mut self, sent_value: u64) {
if sent_value > self.sent_max_stream_data {
self.sent_max_stream_data = sent_value;
}
}
pub(super) fn final_offset_unknown(&self) -> bool {
matches!(self.state, RecvState::Recv { size: None })
}
pub(super) fn can_send_flow_control(&self) -> bool {
self.final_offset_unknown() && !self.stopped
}
pub(super) fn is_receiving(&self) -> bool {
matches!(self.state, RecvState::Recv { .. })
}
fn final_offset(&self) -> Option<u64> {
match self.state {
RecvState::Recv { size } => size,
RecvState::ResetRecvd { size, .. } => Some(size),
}
}
pub(super) fn reset(
&mut self,
error_code: VarInt,
final_offset: VarInt,
received: u64,
max_data: u64,
) -> Result<bool, TransportError> {
if let Some(offset) = self.final_offset() {
if offset != final_offset.into_inner() {
return Err(TransportError::FINAL_SIZE_ERROR("inconsistent value"));
}
} else if self.end > u64::from(final_offset) {
return Err(TransportError::FINAL_SIZE_ERROR(
"lower than high water mark",
));
}
self.credit_consumed_by(final_offset.into(), received, max_data)?;
if matches!(self.state, RecvState::ResetRecvd { .. }) {
return Ok(false);
}
self.state = RecvState::ResetRecvd {
size: final_offset.into(),
error_code,
};
self.assembler.clear();
Ok(true)
}
pub(super) fn reset_code(&self) -> Option<VarInt> {
match self.state {
RecvState::ResetRecvd { error_code, .. } => Some(error_code),
_ => None,
}
}
fn credit_consumed_by(
&self,
offset: u64,
received: u64,
max_data: u64,
) -> Result<u64, TransportError> {
let prev_end = self.end;
let new_bytes = offset.saturating_sub(prev_end);
if offset > self.sent_max_stream_data || received + new_bytes > max_data {
debug!(
received,
new_bytes,
max_data,
offset,
stream_max_data = self.sent_max_stream_data,
"flow control error"
);
return Err(TransportError::FLOW_CONTROL_ERROR(""));
}
Ok(new_bytes)
}
}
/// Chunks of data read from a receive stream.
///
/// Holds the stream's `Recv` state, taken out of `StreamsState`, for the duration of the
/// read. Bytes returned by [`Chunks::next()`] are credited back to connection-level flow
/// control only when [`Chunks::finalize()`] is called or this value is dropped. Prefer
/// `finalize()`, since it reports whether a packet should be transmitted.
pub struct Chunks<'a> {
    /// Stream being read.
    id: StreamId,
    /// Whether data must be returned in stream order.
    ordered: bool,
    /// Owner of the per-stream receive state; the state is returned or freed here.
    streams: &'a mut StreamsState,
    /// Pending frames; flow-control updates are queued here on finalization.
    pending: &'a mut Retransmits,
    /// Progress of this read.
    state: ChunksState,
    /// Total bytes returned by `next()` so far.
    read: u64,
}
impl<'a> Chunks<'a> {
pub(super) fn new(
id: StreamId,
ordered: bool,
streams: &'a mut StreamsState,
pending: &'a mut Retransmits,
) -> Result<Self, ReadableError> {
let mut entry = match streams.recv.entry(id) {
Entry::Occupied(entry) => entry,
Entry::Vacant(_) => return Err(ReadableError::ClosedStream),
};
let mut recv =
match get_or_insert_recv(streams.stream_receive_window)(entry.get_mut()).stopped {
true => return Err(ReadableError::ClosedStream),
false => entry.remove().unwrap().into_inner(), };
recv.assembler.ensure_ordering(ordered)?;
Ok(Self {
id,
ordered,
streams,
pending,
state: ChunksState::Readable(recv),
read: 0,
})
}
pub fn next(&mut self, max_length: usize) -> Result<Option<Chunk>, ReadError> {
let rs = match self.state {
ChunksState::Readable(ref mut rs) => rs,
ChunksState::Reset(error_code) => {
return Err(ReadError::Reset(error_code));
}
ChunksState::Finished => {
return Ok(None);
}
ChunksState::Finalized => panic!("must not call next() after finalize()"),
};
if let Some(chunk) = rs.assembler.read(max_length, self.ordered) {
self.read += chunk.bytes.len() as u64;
return Ok(Some(chunk));
}
match rs.state {
RecvState::ResetRecvd { error_code, .. } => {
debug_assert_eq!(self.read, 0, "reset streams have empty buffers");
let state = mem::replace(&mut self.state, ChunksState::Reset(error_code));
let recv = match state {
ChunksState::Readable(recv) => StreamRecv::Open(recv),
_ => unreachable!("state must be ChunkState::Readable"),
};
self.streams.stream_recv_freed(self.id, recv);
Err(ReadError::Reset(error_code))
}
RecvState::Recv { size } => {
if size == Some(rs.end) && rs.assembler.bytes_read() == rs.end {
let state = mem::replace(&mut self.state, ChunksState::Finished);
let recv = match state {
ChunksState::Readable(recv) => StreamRecv::Open(recv),
_ => unreachable!("state must be ChunkState::Readable"),
};
self.streams.stream_recv_freed(self.id, recv);
Ok(None)
} else {
Err(ReadError::Blocked)
}
}
}
}
pub fn finalize(mut self) -> ShouldTransmit {
self.finalize_inner()
}
fn finalize_inner(&mut self) -> ShouldTransmit {
let state = mem::replace(&mut self.state, ChunksState::Finalized);
if let ChunksState::Finalized = state {
return ShouldTransmit(false);
}
let mut should_transmit = self.streams.queue_max_stream_id(self.pending);
if let ChunksState::Readable(mut rs) = state {
let (_, max_stream_data) = rs.max_stream_data(self.streams.stream_receive_window);
should_transmit |= max_stream_data.0;
if max_stream_data.0 {
self.pending.max_stream_data.insert(self.id);
}
self.streams
.recv
.insert(self.id, Some(StreamRecv::Open(rs)));
}
let max_data = self.streams.add_read_credits(self.read);
self.pending.max_data |= max_data.0;
should_transmit |= max_data.0;
ShouldTransmit(should_transmit)
}
}
impl Drop for Chunks<'_> {
    /// Finalizes implicitly so that credit for data already read is never lost.
    ///
    /// The transmit hint is discarded because nobody is left to act on it. Finalizing a value
    /// that was already finalized does nothing.
    fn drop(&mut self) {
        let _should_transmit = self.finalize_inner();
    }
}
/// Progress of a [`Chunks`] read.
enum ChunksState {
    /// Reading is in progress; holds the stream state taken out of `StreamsState`.
    Readable(Box<Recv>),
    /// The peer reset the stream with this error code; the stream state has been freed.
    Reset(VarInt),
    /// All data up to the final size has been read; the stream state has been freed.
    Finished,
    /// Finalization has run; `next()` must not be called any more.
    Finalized,
}
/// Errors triggered when reading from a receive stream.
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum ReadError {
    /// No data is currently available on this stream; retrying may succeed once more data
    /// arrives.
    #[error("blocked")]
    Blocked,
    /// The peer abandoned the stream, carrying its application-defined error code.
    #[error("reset by peer: code {0}")]
    Reset(VarInt),
    /// The stream is unusable because of a connection error. Not constructed in this module;
    /// presumably produced by higher layers when the connection is lost.
    #[error("stream closed due to connection error")]
    ConnectionClosed,
}
/// Errors triggered when opening a receive stream for reading.
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum ReadableError {
    /// The stream has no receive state (never opened, or already finished/reset and freed) or
    /// has been stopped.
    #[error("closed stream")]
    ClosedStream,
    /// An ordered read was requested after an unordered read on the same stream.
    #[error("ordered read after unordered read")]
    IllegalOrderedRead,
    /// The stream is unusable because of a connection error. Not constructed in this module;
    /// presumably produced by higher layers when the connection is lost.
    #[error("stream closed due to connection error")]
    ConnectionClosed,
}
impl From<IllegalOrderedRead> for ReadableError {
fn from(_: IllegalOrderedRead) -> Self {
Self::IllegalOrderedRead
}
}
/// Receive-side lifecycle of a stream.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum RecvState {
    /// Still accepting data; `size` becomes the final size once a FIN has been accepted.
    Recv { size: Option<u64> },
    /// The peer reset the stream with `error_code`, declaring `size` as the final size.
    ResetRecvd { size: u64, error_code: VarInt },
}
impl Default for RecvState {
fn default() -> Self {
Self::Recv { size: None }
}
}
#[cfg(test)]
mod tests {
    use bytes::Bytes;

    use crate::{Dir, Side};

    use super::*;

    /// Builds a non-FIN STREAM frame on client-initiated unidirectional stream 0.
    fn uni_frame(offset: u64, data: &'static [u8]) -> frame::Stream {
        frame::Stream {
            id: StreamId::new(Side::Client, Dir::Uni, 0),
            offset,
            fin: false,
            data: Bytes::from_static(data),
        }
    }

    /// Frames that arrive after `stop()`, including ones filling earlier gaps, must consume
    /// connection-level credit exactly once and must never re-open stream-level credit.
    #[test]
    fn reordered_frames_while_stopped() {
        const INITIAL_BYTES: u64 = 3;
        const INITIAL_OFFSET: u64 = 3;
        const RECV_WINDOW: u64 = 8;
        // Payload length and connection-level headroom used for every ingested frame.
        const PAYLOAD_LEN: usize = 123;
        const CONN_HEADROOM: u64 = 1024;

        let mut recv = Recv::new(RECV_WINDOW);
        let mut conn_recvd = 0;

        // Data in the middle of the window arrives first, leaving a gap at the start.
        let (fresh, closed) = recv
            .ingest(
                uni_frame(INITIAL_OFFSET, &[0; INITIAL_BYTES as usize]),
                PAYLOAD_LEN,
                conn_recvd,
                conn_recvd + CONN_HEADROOM,
            )
            .unwrap();
        conn_recvd += fresh;
        assert_eq!(fresh, INITIAL_OFFSET + INITIAL_BYTES);
        assert!(!closed);

        // Stopping releases connection credit up to the high-water mark...
        let (credits, transmit) = recv.stop().unwrap();
        assert!(transmit.should_transmit());
        assert_eq!(
            credits,
            INITIAL_OFFSET + INITIAL_BYTES,
            "full connection flow control credit is issued by stop"
        );
        // ...but issues no stream-level credit.
        let (window, transmit) = recv.max_stream_data(RECV_WINDOW);
        assert!(!transmit.should_transmit());
        assert_eq!(
            window, RECV_WINDOW,
            "stream flow control credit isn't issued by stop"
        );

        // A frame at the very end of the window raises the high-water mark.
        let (fresh, closed) = recv
            .ingest(
                uni_frame(RECV_WINDOW - 1, &[0; 1]),
                PAYLOAD_LEN,
                conn_recvd,
                conn_recvd + CONN_HEADROOM,
            )
            .unwrap();
        conn_recvd += fresh;
        assert_eq!(fresh, RECV_WINDOW - (INITIAL_OFFSET + INITIAL_BYTES));
        assert!(!closed);
        let (window, transmit) = recv.max_stream_data(RECV_WINDOW);
        assert!(!transmit.should_transmit());
        assert_eq!(
            window, RECV_WINDOW,
            "stream flow control credit isn't issued after stop"
        );

        // The initial gap is finally filled; it lies entirely below the high-water mark.
        let (fresh, closed) = recv
            .ingest(
                uni_frame(0, &[0; INITIAL_OFFSET as usize]),
                PAYLOAD_LEN,
                conn_recvd,
                conn_recvd + CONN_HEADROOM,
            )
            .unwrap();
        assert_eq!(
            fresh, 0,
            "reordered frames don't issue connection-level flow control for stopped streams"
        );
        assert!(!closed);
        let (window, transmit) = recv.max_stream_data(RECV_WINDOW);
        assert!(!transmit.should_transmit());
        assert_eq!(
            window, RECV_WINDOW,
            "stream flow control credit isn't issued after stop"
        );
    }
}