use std::fmt;
use super::{
frame::MaxStreamsFrame,
varint::{VarInt, WriteVarInt, be_varint},
};
use crate::{
frame::{SendFrame, StreamsBlockedFrame},
net::tx::ArcSendWakers,
role::Role,
};
/// Direction of a QUIC stream.
///
/// The discriminant is the value stored in bit 1 of a packed [`StreamId`]
/// (see [`StreamId::new`]), so the explicit `0` / `1` values are load-bearing
/// and must not change.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Dir {
    /// Bidirectional stream (bit 1 of the stream ID clear).
    Bi = 0,
    /// Unidirectional stream (bit 1 of the stream ID set).
    Uni = 1,
}

impl fmt::Display for Dir {
    /// Writes the lowercase name of the direction.
    ///
    /// Goes through [`fmt::Formatter::pad`], so width, fill, alignment and
    /// precision flags in the format spec are honoured.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Dir::Bi => "bidirectional",
            Dir::Uni => "unidirectional",
        };
        f.pad(name)
    }
}
/// A QUIC stream identifier, stored as its raw packed value.
///
/// The raw value packs three fields (see [`StreamId::new`]):
/// - bit 0: the role that initiated the stream (`0` = client, `1` = server),
/// - bit 1: the stream direction ([`Dir`]; `0` = bidirectional, `1` = unidirectional),
/// - bits 2 and up: the sequence number among streams with the same role and
///   direction ([`StreamId::id`]).
///
/// The derived ordering and hashing operate on the raw packed value.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StreamId(u64);
/// Largest stream sequence number accepted by [`StreamId::new`].
///
/// Equals 2^60 - 1: once the two low bits for role and direction are added,
/// the largest packed stream ID is 2^62 - 1, the top of the QUIC
/// variable-length integer range.
pub const MAX_STREAMS_LIMIT: u64 = (1_u64 << 60) - 1;
impl StreamId {
    /// Packs `role`, `dir` and the sequence number `id` into a stream ID.
    ///
    /// The resulting raw value is `id << 2 | dir << 1 | role`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is greater than [`MAX_STREAMS_LIMIT`].
    pub fn new(role: Role, dir: Dir, id: u64) -> Self {
        assert!(id <= MAX_STREAMS_LIMIT);
        let role_bit = role as u64;
        let dir_bit = (dir as u64) << 1;
        Self((id << 2) | dir_bit | role_bit)
    }

    /// The role that initiated this stream, decoded from bit 0.
    pub fn role(&self) -> Role {
        match self.0 & 0b01 {
            0 => Role::Client,
            _ => Role::Server,
        }
    }

    /// The direction of this stream, decoded from bit 1.
    pub fn dir(&self) -> Dir {
        if self.0 & 0b10 != 0 {
            Dir::Uni
        } else {
            Dir::Bi
        }
    }

    /// The sequence number of this stream among streams with the same role
    /// and direction (the raw value with the two flag bits shifted out).
    pub fn id(&self) -> u64 {
        self.0 >> 2
    }

    /// Returns the stream ID with the same role and direction and the next
    /// sequence number, without checking any limit.
    ///
    /// # Safety
    ///
    /// NOTE(review): nothing memory-unsafe happens here. `unsafe` appears to
    /// mark a caller obligation, presumably that `self.id() < MAX_STREAMS_LIMIT`,
    /// so that the result still satisfies the bound enforced by
    /// [`StreamId::new`]. Confirm at the call sites.
    unsafe fn next_unchecked(&self) -> Self {
        // Adding 4 bumps the sequence number while leaving bits 0-1 intact.
        Self(self.0 + 4)
    }

    /// Number of bytes this ID occupies when encoded as a varint.
    ///
    /// # Panics
    ///
    /// Panics if the raw value is rejected by the `From<StreamId> for VarInt`
    /// conversion.
    pub fn encoding_size(&self) -> usize {
        VarInt::from(*self).encoding_size()
    }
}
impl fmt::Display for StreamId {
    /// Renders the ID as `"<role> side <direction> stream <sequence number>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (role, dir, seq) = (self.role(), self.dir(), self.id());
        write!(f, "{} side {} stream {}", role, dir, seq)
    }
}
impl From<VarInt> for StreamId {
fn from(v: VarInt) -> Self {
Self(v.into_inner())
}
}
impl From<StreamId> for VarInt {
fn from(s: StreamId) -> Self {
VarInt::from_u64(s.0).expect("stream id must be less than VARINT_MAX")
}
}
impl From<StreamId> for u64 {
fn from(s: StreamId) -> Self {
s.0
}
}
pub fn be_streamid(input: &[u8]) -> nom::IResult<&[u8], StreamId> {
use nom::{Parser, combinator::map};
map(be_varint, StreamId::from).parse(input)
}
/// Extension trait for writing a [`StreamId`] into a [`bytes::BufMut`].
///
/// Every `BufMut` gets this trait through a blanket impl.
pub trait WriteStreamId: bytes::BufMut {
    /// Writes `stream_id` as a varint (its raw packed value).
    fn put_streamid(&mut self, stream_id: &StreamId);
}
impl<T: bytes::BufMut> WriteStreamId for T {
fn put_streamid(&mut self, stream_id: &StreamId) {
self.put_varint(&(*stream_id).into());
}
}
/// Policy hook consulted about the concurrency of streams opened by the peer.
///
/// `StreamIds::new` hands a boxed instance to `ArcRemoteStreamIds::new`, i.e.
/// to the half that tracks peer-initiated streams.
///
/// Every method returns an `Option<u64>`. NOTE(review): presumably `Some(n)`
/// means "raise the peer's stream limit for `dir` to `n`" (to be announced in a
/// MAX_STREAMS frame) and `None` means "leave it unchanged". Confirm against
/// `remote_sid`.
pub trait ControlStreamsConcurrency: fmt::Debug + Send + Sync {
    /// Called when the peer's stream `sid` in direction `dir` is accepted.
    ///
    /// NOTE(review): `sid` is a bare `u64`, presumably the stream sequence
    /// number rather than the packed ID. Confirm.
    /// The result must not be ignored (`#[must_use]`).
    #[must_use]
    fn on_accept_streams(&mut self, dir: Dir, sid: u64) -> Option<u64>;
    /// Called when stream `sid` in direction `dir` ends (presumably once it is
    /// fully closed. Confirm).
    fn on_end_of_stream(&mut self, dir: Dir, sid: u64) -> Option<u64>;
    /// Called when the peer reports being blocked at `max_streams` in direction
    /// `dir` (presumably on receipt of a STREAMS_BLOCKED frame. Confirm).
    fn on_streams_blocked(&mut self, dir: Dir, max_streams: u64) -> Option<u64>;
}
/// Lets a boxed controller (including `Box<dyn ControlStreamsConcurrency>`) be
/// used wherever a controller is expected, by forwarding every call to the
/// boxed value.
impl<C: ?Sized + ControlStreamsConcurrency> ControlStreamsConcurrency for Box<C> {
    fn on_accept_streams(&mut self, dir: Dir, sid: u64) -> Option<u64> {
        let inner: &mut C = &mut **self;
        inner.on_accept_streams(dir, sid)
    }

    fn on_end_of_stream(&mut self, dir: Dir, sid: u64) -> Option<u64> {
        let inner: &mut C = &mut **self;
        inner.on_end_of_stream(dir, sid)
    }

    fn on_streams_blocked(&mut self, dir: Dir, max_streams: u64) -> Option<u64> {
        let inner: &mut C = &mut **self;
        inner.on_streams_blocked(dir, max_streams)
    }
}
/// Factory for [`ControlStreamsConcurrency`] controllers.
///
/// Any `Fn(u64, u64) -> C` that is `Send + Sync` qualifies, through a blanket
/// impl in this module.
pub trait ProductStreamsConcurrencyController: Send + Sync {
    /// Creates a boxed controller from the initial bidirectional and
    /// unidirectional stream limits.
    ///
    /// NOTE(review): presumably these are the initial limits this endpoint
    /// grants the peer. Confirm at the call site.
    fn init(
        &self,
        init_max_bidi_streams: u64,
        init_max_uni_streams: u64,
    ) -> Box<dyn ControlStreamsConcurrency>;
}
/// Any thread-safe `Fn(bidi, uni) -> C` acts as a controller factory.
///
/// The closure is called with the two initial limits and its result is boxed.
impl<F, C> ProductStreamsConcurrencyController for F
where
    F: Fn(u64, u64) -> C + Send + Sync,
    C: ControlStreamsConcurrency + 'static,
{
    #[inline]
    fn init(
        &self,
        init_max_bidi_streams: u64,
        init_max_uni_streams: u64,
    ) -> Box<dyn ControlStreamsConcurrency> {
        let controller: C = self(init_max_bidi_streams, init_max_uni_streams);
        Box::new(controller)
    }
}
pub mod handy;
pub mod local_sid;
pub use local_sid::ArcLocalStreamIds;
pub mod remote_sid;
pub use remote_sid::ArcRemoteStreamIds;
/// Stream-ID bookkeeping for one connection, split by which endpoint opens the
/// streams.
#[derive(Debug, Clone)]
pub struct StreamIds<BLOCKED, MAX> {
    /// IDs of streams opened by this endpoint (built with this endpoint's role).
    /// `BLOCKED` is its frame sink, presumably for STREAMS_BLOCKED frames.
    pub local: ArcLocalStreamIds<BLOCKED>,
    /// IDs of streams opened by the peer (built with the opposite role).
    /// `MAX` is its frame sink, presumably for MAX_STREAMS frames.
    pub remote: ArcRemoteStreamIds<MAX>,
}
impl<T> StreamIds<T, T>
where
    T: SendFrame<MaxStreamsFrame> + SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    /// Builds the local and remote stream-ID managers for a connection.
    ///
    /// The limits deliberately cross over:
    /// - The `local` half (streams opened under `role`) gets the `remote_max_*`
    ///   limits.
    /// - The `remote` half (streams opened under `!role`) gets the
    ///   `local_max_*` limits.
    ///
    /// How the handles are distributed:
    /// - `sid_frames_tx` is shared: the local half gets a clone and the remote
    ///   half gets the original.
    /// - `tx_wakers` goes to the local half only.
    /// - `ctrl` goes to the remote half only.
    // Every parameter is forwarded to exactly one of the two halves, so the long
    // argument list mirrors the two constructors rather than hiding a missing type.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        role: Role,
        local_max_bi: u64,
        local_max_uni: u64,
        remote_max_bi: u64,
        remote_max_uni: u64,
        sid_frames_tx: T,
        ctrl: Box<dyn ControlStreamsConcurrency>,
        tx_wakers: ArcSendWakers,
    ) -> Self {
        // Struct-literal fields are evaluated in the order written, so the local
        // half is built first and takes its clone before the sink is moved below.
        Self {
            local: ArcLocalStreamIds::new(
                role,
                remote_max_bi,
                remote_max_uni,
                sid_frames_tx.clone(),
                tx_wakers,
            ),
            remote: ArcRemoteStreamIds::new(
                !role,
                local_max_bi,
                local_max_uni,
                sid_frames_tx,
                ctrl,
            ),
        }
    }
}