use std::collections::hash_map::Entry;
use std::mem;
use thiserror::Error;
use tracing::debug;
use super::{Retransmits, ShouldTransmit, StreamHalf, StreamId, StreamsState, UnknownStream};
use crate::connection::assembler::{Assembler, Chunk, IllegalOrderedRead};
use crate::{frame, Dir, TransportError, VarInt};
#[derive(Debug, Default)]
pub(super) struct Recv {
state: RecvState,
pub(super) assembler: Assembler,
sent_max_stream_data: u64,
pub(super) end: u64,
pub(super) stopped: bool,
}
impl Recv {
pub(super) fn new(initial_max_data: u64) -> Self {
Self {
state: RecvState::default(),
assembler: Assembler::new(),
sent_max_stream_data: initial_max_data,
end: 0,
stopped: false,
}
}
pub(super) fn ingest(
&mut self,
frame: frame::Stream,
payload_len: usize,
received: u64,
max_data: u64,
) -> Result<(u64, bool), TransportError> {
let end = frame.offset + frame.data.len() as u64;
if end >= 2u64.pow(62) {
return Err(TransportError::FLOW_CONTROL_ERROR(
"maximum stream offset too large",
));
}
if let Some(final_offset) = self.final_offset() {
if end > final_offset || (frame.fin && end != final_offset) {
debug!(end, final_offset, "final size error");
return Err(TransportError::FINAL_SIZE_ERROR(""));
}
}
let new_bytes = self.credit_consumed_by(end, received, max_data)?;
if frame.fin && !self.stopped {
if let RecvState::Recv { ref mut size } = self.state {
*size = Some(end);
}
}
self.end = self.end.max(end);
if !self.stopped {
self.assembler.insert(frame.offset, frame.data, payload_len);
} else {
self.assembler.set_bytes_read(end);
}
Ok((new_bytes, frame.fin && self.stopped))
}
pub(super) fn stop(&mut self) -> Result<(u64, ShouldTransmit), UnknownStream> {
if self.stopped {
return Err(UnknownStream { _private: () });
}
self.stopped = true;
self.assembler.clear();
let read_credits = self.end - self.assembler.bytes_read();
Ok((read_credits, ShouldTransmit(self.is_receiving())))
}
pub(super) fn max_stream_data(&mut self, stream_receive_window: u64) -> (u64, ShouldTransmit) {
let max_stream_data = self.assembler.bytes_read() + stream_receive_window;
let diff = max_stream_data - self.sent_max_stream_data;
let transmit = self.receiving_unknown_size() && diff >= (stream_receive_window / 8);
(max_stream_data, ShouldTransmit(transmit))
}
pub fn record_sent_max_stream_data(&mut self, sent_value: u64) {
if sent_value > self.sent_max_stream_data {
self.sent_max_stream_data = sent_value;
}
}
pub(super) fn receiving_unknown_size(&self) -> bool {
matches!(self.state, RecvState::Recv { size: None })
}
pub(super) fn is_receiving(&self) -> bool {
matches!(self.state, RecvState::Recv { .. })
}
fn final_offset(&self) -> Option<u64> {
match self.state {
RecvState::Recv { size } => size,
RecvState::ResetRecvd { size, .. } => Some(size),
}
}
pub(super) fn reset(
&mut self,
error_code: VarInt,
final_offset: VarInt,
received: u64,
max_data: u64,
) -> Result<bool, TransportError> {
if let Some(offset) = self.final_offset() {
if offset != final_offset.into() {
return Err(TransportError::FINAL_SIZE_ERROR("inconsistent value"));
}
} else if self.end > final_offset.into() {
return Err(TransportError::FINAL_SIZE_ERROR(
"lower than high water mark",
));
}
self.credit_consumed_by(final_offset.into(), received, max_data)?;
if matches!(self.state, RecvState::ResetRecvd { .. }) {
return Ok(false);
}
self.state = RecvState::ResetRecvd {
size: final_offset.into(),
error_code,
};
self.assembler.clear();
Ok(true)
}
fn credit_consumed_by(
&self,
offset: u64,
received: u64,
max_data: u64,
) -> Result<u64, TransportError> {
let prev_end = self.end;
let new_bytes = offset.saturating_sub(prev_end);
if offset > self.sent_max_stream_data || received + new_bytes > max_data {
debug!(
received,
new_bytes,
max_data,
offset,
stream_max_data = self.sent_max_stream_data,
"flow control error"
);
return Err(TransportError::FLOW_CONTROL_ERROR(""));
}
Ok(new_bytes)
}
}
pub struct Chunks<'a> {
id: StreamId,
ordered: bool,
streams: &'a mut StreamsState,
pending: &'a mut Retransmits,
state: ChunksState,
read: u64,
}
impl<'a> Chunks<'a> {
pub(super) fn new(
id: StreamId,
ordered: bool,
streams: &'a mut StreamsState,
pending: &'a mut Retransmits,
) -> Result<Self, ReadableError> {
let entry = match streams.recv.entry(id) {
Entry::Occupied(entry) => entry,
Entry::Vacant(_) => return Err(ReadableError::UnknownStream),
};
let mut recv = match entry.get().stopped {
true => return Err(ReadableError::UnknownStream),
false => entry.remove(),
};
recv.assembler.ensure_ordering(ordered)?;
Ok(Self {
id,
ordered,
streams,
pending,
state: ChunksState::Readable(recv),
read: 0,
})
}
pub fn next(&mut self, max_length: usize) -> Result<Option<Chunk>, ReadError> {
let rs = match self.state {
ChunksState::Readable(ref mut rs) => rs,
ChunksState::Reset(error_code) => {
return Err(ReadError::Reset(error_code));
}
ChunksState::Finished => {
return Ok(None);
}
ChunksState::Finalized => panic!("must not call next() after finalize()"),
};
if let Some(chunk) = rs.assembler.read(max_length, self.ordered) {
self.read += chunk.bytes.len() as u64;
return Ok(Some(chunk));
}
match rs.state {
RecvState::ResetRecvd { error_code, .. } => {
debug_assert_eq!(self.read, 0, "reset streams have empty buffers");
self.streams.stream_freed(self.id, StreamHalf::Recv);
self.state = ChunksState::Reset(error_code);
Err(ReadError::Reset(error_code))
}
RecvState::Recv { size } => {
if size == Some(rs.end) && rs.assembler.bytes_read() == rs.end {
self.streams.stream_freed(self.id, StreamHalf::Recv);
self.state = ChunksState::Finished;
Ok(None)
} else {
Err(ReadError::Blocked)
}
}
}
}
pub fn finalize(mut self) -> ShouldTransmit {
self.finalize_inner(false)
}
fn finalize_inner(&mut self, drop: bool) -> ShouldTransmit {
let state = mem::replace(&mut self.state, ChunksState::Finalized);
debug_assert!(
!drop || matches!(state, ChunksState::Finalized),
"finalize must be called before drop"
);
if let ChunksState::Finalized = state {
return ShouldTransmit(false);
}
let mut should_transmit = false;
if matches!(state, ChunksState::Finished | ChunksState::Reset(_))
&& self.streams.side != self.id.initiator()
{
match self.id.dir() {
Dir::Uni => self.pending.max_uni_stream_id = true,
Dir::Bi => self.pending.max_bi_stream_id = true,
}
should_transmit = true;
}
if let ChunksState::Readable(mut rs) = state {
let (_, max_stream_data) = rs.max_stream_data(self.streams.stream_receive_window);
should_transmit |= max_stream_data.0;
if max_stream_data.0 {
self.pending.max_stream_data.insert(self.id);
}
self.streams.recv.insert(self.id, rs);
}
let max_data = self.streams.add_read_credits(self.read);
self.pending.max_data |= max_data.0;
should_transmit |= max_data.0;
ShouldTransmit(should_transmit)
}
}
impl<'a> Drop for Chunks<'a> {
fn drop(&mut self) {
let _ = self.finalize_inner(true);
}
}
enum ChunksState {
Readable(Recv),
Reset(VarInt),
Finished,
Finalized,
}
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum ReadError {
#[error("blocked")]
Blocked,
#[error("reset by peer: code {0}")]
Reset(VarInt),
}
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum ReadableError {
#[error("unknown stream")]
UnknownStream,
#[error("ordered read after unordered read")]
IllegalOrderedRead,
}
impl From<IllegalOrderedRead> for ReadableError {
fn from(_: IllegalOrderedRead) -> Self {
ReadableError::IllegalOrderedRead
}
}
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum RecvState {
Recv { size: Option<u64> },
ResetRecvd { size: u64, error_code: VarInt },
}
impl Default for RecvState {
fn default() -> Self {
RecvState::Recv { size: None }
}
}