use std::{
cell::RefCell,
collections::{hash_map, BinaryHeap, VecDeque},
};
use bytes::Bytes;
use thiserror::Error;
use tracing::trace;
use super::spaces::{Retransmits, ThinRetransmits};
use crate::{frame, Dir, StreamId, VarInt};
mod recv;
use recv::Recv;
pub use recv::{Chunks, ReadError, ReadableError};
mod send;
pub use send::{ByteSlice, BytesArray, BytesSource, FinishError, WriteError, Written};
use send::{Send, SendState};
mod state;
pub use state::StreamsState;
/// Handle for opening locally-initiated streams and accepting peer-initiated ones.
///
/// Mutably borrows the connection's [`StreamsState`] and immutably borrows the
/// connection-level state for as long as the handle lives.
pub struct Streams<'a> {
    pub(super) state: &'a mut StreamsState,
    pub(super) conn_state: &'a super::State,
}

impl<'a> Streams<'a> {
    /// Builds a handle directly from its parts (fuzzing builds only).
    #[cfg(fuzzing)]
    pub fn new(state: &'a mut StreamsState, conn_state: &'a super::State) -> Self {
        Self { state, conn_state }
    }

    /// Opens the next stream initiated by this endpoint in direction `dir`.
    ///
    /// Returns `None` if the connection is closed, or if the next index for `dir` has
    /// already reached the current limit stored in `max`. On success the new stream is
    /// inserted into the stream state and counted in `send_streams`.
    pub fn open(&mut self, dir: Dir) -> Option<StreamId> {
        let slot = dir as usize;
        // Closed check first; the limit comparison only runs for an open connection.
        if self.conn_state.is_closed() || self.state.next[slot] >= self.state.max[slot] {
            return None;
        }

        let index = self.state.next[slot];
        self.state.next[slot] = index + 1;
        // Our own `side` marks the stream as locally initiated.
        let id = StreamId::new(self.state.side, dir, index);
        // NOTE(review): the `false` flag presumably means "not remote-initiated"; confirm in state.rs.
        self.state.insert(false, id);
        self.state.send_streams += 1;
        Some(id)
    }

    /// Returns the next peer-initiated stream in direction `dir` that has not yet been handed out.
    ///
    /// Returns `None` once every stream counted by `next_remote` has been reported.
    /// Accepting a bidirectional stream also bumps `send_streams`, because such a stream can
    /// be written to as well.
    pub fn accept(&mut self, dir: Dir) -> Option<StreamId> {
        let slot = dir as usize;
        let index = self.state.next_reported_remote[slot];
        if index == self.state.next_remote[slot] {
            return None;
        }

        self.state.next_reported_remote[slot] = index + 1;
        if dir == Dir::Bi {
            self.state.send_streams += 1;
        }
        // The opposite side initiated this stream.
        Some(StreamId::new(!self.state.side, dir, index))
    }

    /// Exposes the underlying stream state (fuzzing builds only).
    #[cfg(fuzzing)]
    pub fn state(&mut self) -> &mut StreamsState {
        self.state
    }

    /// Returns the `send_streams` counter.
    ///
    /// The counter is incremented by `open` and by `accept` for bidirectional streams.
    pub fn send_streams(&self) -> usize {
        self.state.send_streams
    }
}
/// Handle to the receiving half of the stream identified by `id`.
pub struct RecvStream<'a> {
    pub(super) id: StreamId,
    pub(super) state: &'a mut StreamsState,
    pub(super) pending: &'a mut Retransmits,
}

impl<'a> RecvStream<'a> {
    /// Begins a read on this stream.
    ///
    /// All of the work is done by [`Chunks::new`], which receives `ordered` unchanged.
    pub fn read(&mut self, ordered: bool) -> Result<Chunks, ReadableError> {
        Chunks::new(self.id, ordered, self.state, self.pending)
    }

    /// Stops receiving on this stream, reporting `error_code` to the peer if needed.
    ///
    /// Steps, in order:
    /// - Queues a STOP_SENDING frame when the stream's `stop` says one should be sent.
    /// - Drops the receive state unless the stream is still `receiving_unknown_size`.
    /// - Returns the released read credit to connection-level flow control. If that calls
    ///   for an update, sets `pending.max_data`.
    ///
    /// # Errors
    /// Returns [`UnknownStream`] if the stream is not in `recv`, or any error from the
    /// stream's own `stop`.
    pub fn stop(&mut self, error_code: VarInt) -> Result<(), UnknownStream> {
        let mut entry = match self.state.recv.entry(self.id) {
            hash_map::Entry::Vacant(_) => return Err(UnknownStream { _private: () }),
            hash_map::Entry::Occupied(occupied) => occupied,
        };
        let recv = entry.get_mut();

        let (read_credits, stop_sending) = recv.stop()?;
        if stop_sending.should_transmit() {
            let frame = frame::StopSending {
                id: self.id,
                error_code,
            };
            self.pending.stop_sending.push(frame);
        }

        // A stream that is still `receiving_unknown_size` keeps its entry. NOTE(review):
        // presumably this lets later data still be accounted for in flow control; confirm in recv.rs.
        if !recv.receiving_unknown_size() {
            entry.remove();
            self.state.stream_freed(self.id, StreamHalf::Recv);
        }

        let credit_update = self.state.add_read_credits(read_credits);
        if credit_update.should_transmit() {
            self.pending.max_data = true;
        }
        Ok(())
    }
}
/// Handle to the sending half of the stream identified by `id`.
pub struct SendStream<'a> {
    pub(super) id: StreamId,
    pub(super) state: &'a mut StreamsState,
    pub(super) pending: &'a mut Retransmits,
    pub(super) conn_state: &'a super::State,
}

impl<'a> SendStream<'a> {
    /// Builds a handle directly from its parts (fuzzing builds only).
    #[cfg(fuzzing)]
    pub fn new(
        id: StreamId,
        state: &'a mut StreamsState,
        pending: &'a mut Retransmits,
        conn_state: &'a super::State,
    ) -> Self {
        Self {
            id,
            state,
            pending,
            conn_state,
        }
    }

    /// Writes as much of `data` as current limits allow and returns the number of bytes taken.
    ///
    /// # Errors
    /// Same conditions as [`SendStream::write_chunks`].
    pub fn write(&mut self, data: &[u8]) -> Result<usize, WriteError> {
        self.write_source(&mut ByteSlice::from_slice(data))
            .map(|written| written.bytes)
    }

    /// Writes as much of the `data` chunks as current limits allow.
    ///
    /// # Errors
    /// - `WriteError::Blocked` if the connection is closed or `write_limit()` is zero.
    /// - `WriteError::UnknownStream` if the stream is not in `send`.
    /// - Any error from the stream's own `write`.
    pub fn write_chunks(&mut self, data: &mut [Bytes]) -> Result<Written, WriteError> {
        let mut source = BytesArray::from_chunks(data);
        self.write_source(&mut source)
    }

    /// Shared implementation of `write` and `write_chunks`.
    ///
    /// On success, adds the written byte count to `data_sent` and `unacked_data`. If the
    /// stream was not already pending, it is then queued at its current priority.
    fn write_source<B: BytesSource>(&mut self, source: &mut B) -> Result<Written, WriteError> {
        if self.conn_state.is_closed() {
            trace!(%self.id, "write blocked; connection draining");
            return Err(WriteError::Blocked);
        }

        // Read the limit before `stream` mutably borrows `self.state.send`.
        let limit = self.state.write_limit();
        let stream = match self.state.send.get_mut(&self.id) {
            Some(stream) => stream,
            None => return Err(WriteError::UnknownStream),
        };

        if limit == 0 {
            trace!(stream = %self.id, "write blocked by connection-level flow control or send window");
            // Record the stream in `connection_blocked` at most once.
            if !stream.connection_blocked {
                stream.connection_blocked = true;
                self.state.connection_blocked.push(self.id);
            }
            return Err(WriteError::Blocked);
        }

        let was_pending = stream.is_pending();
        let written = stream.write(source, limit)?;
        let byte_count = written.bytes as u64;
        self.state.data_sent += byte_count;
        self.state.unacked_data += byte_count;
        trace!(stream = %self.id, "wrote {} bytes", written.bytes);
        if !was_pending {
            push_pending(&mut self.state.pending, self.id, stream.priority);
        }
        Ok(written)
    }

    /// Returns the stream's recorded `stop_reason`.
    ///
    /// NOTE(review): presumably this is set when the peer sends STOP_SENDING; not visible here.
    ///
    /// # Errors
    /// [`UnknownStream`] if the stream is not in `send`.
    pub fn stopped(&mut self) -> Result<Option<VarInt>, UnknownStream> {
        self.state
            .send
            .get(&self.id)
            .map(|stream| stream.stop_reason)
            .ok_or(UnknownStream { _private: () })
    }

    /// Marks the stream finished and queues it for transmission if it was not already pending.
    ///
    /// # Errors
    /// `FinishError::UnknownStream` if the stream is not in `send`, or any error from the
    /// stream's own `finish`.
    pub fn finish(&mut self) -> Result<(), FinishError> {
        let stream = match self.state.send.get_mut(&self.id) {
            Some(stream) => stream,
            None => return Err(FinishError::UnknownStream),
        };

        let was_pending = stream.is_pending();
        stream.finish()?;
        if !was_pending {
            push_pending(&mut self.state.pending, self.id, stream.priority);
        }
        Ok(())
    }

    /// Abandons the stream and queues a RESET_STREAM carrying `error_code`.
    ///
    /// The stream's unacknowledged bytes are subtracted from `unacked_data` before it is reset.
    ///
    /// # Errors
    /// [`UnknownStream`] if the stream is not in `send`, or if it is already in `ResetSent`
    /// (a repeated reset).
    pub fn reset(&mut self, error_code: VarInt) -> Result<(), UnknownStream> {
        let stream = self
            .state
            .send
            .get_mut(&self.id)
            .ok_or(UnknownStream { _private: () })?;

        if matches!(stream.state, SendState::ResetSent) {
            return Err(UnknownStream { _private: () });
        }

        self.state.unacked_data -= stream.pending.unacked();
        stream.reset();
        self.pending.reset_stream.push((self.id, error_code));
        Ok(())
    }

    /// Sets the stream's send priority.
    ///
    /// Only the stored field changes. Within this file, the value is read when the stream is
    /// next queued by `write_source` or `finish`.
    ///
    /// # Errors
    /// [`UnknownStream`] if the stream is not in `send`.
    pub fn set_priority(&mut self, priority: i32) -> Result<(), UnknownStream> {
        match self.state.send.get_mut(&self.id) {
            Some(stream) => {
                stream.priority = priority;
                Ok(())
            }
            None => Err(UnknownStream { _private: () }),
        }
    }

    /// Returns the stream's current send priority.
    ///
    /// # Errors
    /// [`UnknownStream`] if the stream is not in `send`.
    pub fn priority(&self) -> Result<i32, UnknownStream> {
        self.state
            .send
            .get(&self.id)
            .map(|stream| stream.priority)
            .ok_or(UnknownStream { _private: () })
    }
}
/// Queues stream `id` for transmission at `priority` in the pending-send heap.
///
/// Tries three options in order:
/// 1. Append `id` to the existing level whose priority equals `priority`.
/// 2. If the heap holds exactly one level and its queue is empty, reuse that level for
///    `priority`.
/// 3. Otherwise, push a new level.
///
/// Why step 2 exists: a lone empty level can remain in the heap after its queue drains.
/// NOTE(review): the draining code is outside this file; confirm it keeps a single empty
/// level rather than popping it. Without reuse, queuing a *lower* priority would put a
/// non-empty level beneath an empty top level. A consumer that stops at an empty top level
/// would then never reach the newly queued stream.
fn push_pending(pending: &mut BinaryHeap<PendingLevel>, id: StreamId, priority: i32) {
    // Levels are only reachable through `&` here; `RefCell` allows appending anyway.
    if let Some(level) = pending.iter().find(|level| level.priority == priority) {
        level.queue.borrow_mut().push_back(id);
        return;
    }

    if pending.len() == 1 {
        if let Some(mut only) = pending.peek_mut() {
            if only.queue.get_mut().is_empty() {
                only.queue.get_mut().push_back(id);
                // Changing the key through `PeekMut` is allowed (heap order is restored on
                // drop), and with a single element there is nothing to reorder anyway.
                only.priority = priority;
                return;
            }
        }
    }

    let mut queue = VecDeque::new();
    queue.push_back(id);
    pending.push(PendingLevel {
        queue: RefCell::new(queue),
        priority,
    });
}
/// One level of the pending-send heap: the streams queued at a single `priority`.
///
/// Equality and ordering look only at `priority`. Because `BinaryHeap` is a max-heap, the
/// level with the highest priority is the one returned by `peek`/`pop`.
struct PendingLevel {
    // Wrapped in `RefCell` so `push_pending` can append to a level it reaches through
    // `BinaryHeap::iter()`, which only yields shared references.
    queue: RefCell<VecDeque<StreamId>>,
    priority: i32,
}

impl Ord for PendingLevel {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.priority, &other.priority)
    }
}

impl PartialOrd for PendingLevel {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        // Delegate to `Ord` so the two orderings can never disagree.
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for PendingLevel {
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority
    }
}

impl Eq for PendingLevel {}
/// Stream-related events.
///
/// NOTE(review): the code that emits these is outside this file. The per-variant notes only
/// describe the payload; the meanings given are inferred from the variant names.
#[derive(Debug)]
pub enum StreamEvent {
    /// Carries a direction; presumably new peer-initiated streams in `dir` can be accepted.
    Opened { dir: Dir },
    /// Carries a stream id; presumably `id` has data to read.
    Readable { id: StreamId },
    /// Carries a stream id; presumably `id` can accept more writes.
    Writable { id: StreamId },
    /// Carries a stream id; presumably the sending side of `id` has finished.
    Finished { id: StreamId },
    /// Carries a stream id and the error code associated with stopping it.
    Stopped { id: StreamId, error_code: VarInt },
    /// Carries a direction; presumably more locally-initiated streams in `dir` can be opened.
    Available { dir: Dir },
}
/// Boolean signal returned by operations that may need a frame queued for transmission.
///
/// `Default` is `false` (no frame needed).
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
#[must_use = "A frame might need to be enqueued"]
pub struct ShouldTransmit(bool);

impl ShouldTransmit {
    /// Returns `true` when the caller should enqueue the corresponding frame.
    pub fn should_transmit(self) -> bool {
        let Self(needed) = self;
        needed
    }
}
#[derive(Debug, Error, Clone, PartialEq, Eq)]
#[error("unknown stream")]
pub struct UnknownStream {
_private: (),
}
/// Selects one half of a stream.
///
/// Passed to `stream_freed`, for example after `RecvStream::stop` removes a stream's
/// receive state.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum StreamHalf {
    /// The sending half.
    Send,
    /// The receiving half.
    Recv,
}