use std::{fmt, ops};
use super::{
frame::MaxStreamsFrame,
varint::{be_varint, VarInt, WriteVarInt},
};
use crate::frame::{SendFrame, StreamsBlockedFrame};
/// Which side of a connection an endpoint (or a stream's initiator) is.
///
/// The explicit discriminants are load-bearing: `StreamId::new` uses
/// `role as u64` as the lowest bit of a stream identifier.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Role {
    Client = 0,
    Server = 1,
}

/// Formats as `client` / `server`, honouring width, fill, alignment and
/// precision flags (via [`fmt::Formatter::pad`]).
impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Role::Client => "client",
            Role::Server => "server",
        };
        f.pad(label)
    }
}

/// `!role` yields the opposite role.
impl ops::Not for Role {
    type Output = Role;

    fn not(self) -> Role {
        match self {
            Role::Server => Role::Client,
            Role::Client => Role::Server,
        }
    }
}
/// Stream directionality: bidirectional or unidirectional.
///
/// The explicit discriminants are load-bearing: `StreamId::new` places
/// `dir as u64` in bit 1 of a stream identifier.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum Dir {
    Bi = 0,
    Uni = 1,
}

/// Formats as `bidirectional` / `unidirectional`, honouring width, fill,
/// alignment and precision flags (via [`fmt::Formatter::pad`]).
impl fmt::Display for Dir {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            Dir::Uni => "unidirectional",
            Dir::Bi => "bidirectional",
        };
        f.pad(text)
    }
}
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StreamId(u64);
pub const MAX_STREAMS_LIMIT: u64 = (1 << 60) - 1;
impl StreamId {
fn new(role: Role, dir: Dir, id: u64) -> Self {
assert!(id <= MAX_STREAMS_LIMIT);
Self((((id << 1) | (dir as u64)) << 1) | (role as u64))
}
pub fn role(&self) -> Role {
if self.0 & 0x1 == 0 {
Role::Client
} else {
Role::Server
}
}
pub fn dir(&self) -> Dir {
if self.0 & 2 == 0 {
Dir::Bi
} else {
Dir::Uni
}
}
pub fn id(&self) -> u64 {
self.0 >> 2
}
unsafe fn next_unchecked(&self) -> Self {
Self(self.0 + 4)
}
pub fn encoding_size(&self) -> usize {
VarInt::from(*self).encoding_size()
}
}
/// Human-readable form, e.g. `client side bidirectional stream 3`.
///
/// Unlike `Role`/`Dir`, the composed string does not apply the caller's
/// width/alignment flags as a whole.
impl fmt::Display for StreamId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (role, dir, seq) = (self.role(), self.dir(), self.id());
        write!(f, "{} side {} stream {}", role, dir, seq)
    }
}
impl From<VarInt> for StreamId {
fn from(v: VarInt) -> Self {
Self(v.into_inner())
}
}
impl From<StreamId> for VarInt {
fn from(s: StreamId) -> Self {
VarInt::from_u64(s.0).expect("stream id must be less than VARINT_MAX")
}
}
impl From<StreamId> for u64 {
fn from(s: StreamId) -> Self {
s.0
}
}
pub fn be_streamid(input: &[u8]) -> nom::IResult<&[u8], StreamId> {
use nom::combinator::map;
map(be_varint, StreamId::from)(input)
}
pub trait WriteStreamId: bytes::BufMut {
fn put_streamid(&mut self, stream_id: &StreamId);
}
impl<T: bytes::BufMut> WriteStreamId for T {
fn put_streamid(&mut self, stream_id: &StreamId) {
self.put_varint(&(*stream_id).into());
}
}
/// Pluggable policy consulted by the stream-id bookkeeping.
///
/// `StreamIds::new` hands the boxed implementation to `ArcRemoteStreamIds`
/// only, so these hooks presumably concern peer-initiated streams (confirm
/// in `remote_sid`). Implementations must be `Debug + Send + Sync`.
///
/// NOTE(review): each hook returns `Option<u64>`; presumably `Some(n)` asks
/// the caller to advertise a new stream limit `n` and `None` means "no
/// change". This is not verifiable from this file, so confirm against the
/// callers.
pub trait ControlConcurrency: fmt::Debug + Send + Sync {
    /// Called when streams in direction `dir` are accepted; `sid` is a
    /// stream sequence number (TODO confirm: newly accepted id vs. count).
    ///
    /// NOTE(review): only this hook is `#[must_use]`, although the other two
    /// also return `Option<u64>`. Confirm whether the asymmetry is intended.
    #[must_use]
    fn on_accept_streams(&mut self, dir: Dir, sid: u64) -> Option<u64>;
    /// Called when the stream with sequence number `sid` in direction `dir`
    /// ends.
    fn on_end_of_stream(&mut self, dir: Dir, sid: u64) -> Option<u64>;
    /// Called when a streams-blocked signal for direction `dir` reports
    /// `max_streams` as the limit it is blocked at.
    fn on_streams_blocked(&mut self, dir: Dir, max_streams: u64) -> Option<u64>;
}
pub mod handy;
pub mod local_sid;
pub use local_sid::ArcLocalStreamIds;
pub mod remote_sid;
pub use remote_sid::ArcRemoteStreamIds;
/// Stream-identifier state for both initiators.
///
/// - `local` tracks streams this endpoint opens (built with this endpoint's role).
/// - `remote` tracks streams the peer opens (built with the opposite role).
///
/// `BLOCKED` / `MAX` are the frame sinks of the two halves. Presumably they
/// emit STREAMS_BLOCKED and MAX_STREAMS frames respectively; confirm in
/// `local_sid` / `remote_sid`.
#[derive(Debug, Clone)]
pub struct StreamIds<BLOCKED, MAX> {
    pub local: ArcLocalStreamIds<BLOCKED>,
    pub remote: ArcRemoteStreamIds<MAX>,
}

impl<T> StreamIds<T, T>
where
    T: SendFrame<MaxStreamsFrame> + SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    /// Creates the stream-id state for an endpoint playing `role`.
    ///
    /// * `role` - this endpoint's role. The peer half is created with `!role`.
    /// * `max_bi_streams` / `max_uni_streams` - forwarded to
    ///   `ArcRemoteStreamIds::new`. Presumably these are the limits granted to
    ///   the peer; confirm in `remote_sid`.
    /// * `sid_frames_tx` - frame sink; the local half gets a clone and the
    ///   remote half takes the original.
    /// * `ctrl` - concurrency policy, given to the remote half only.
    pub fn new(
        role: Role,
        max_bi_streams: u64,
        max_uni_streams: u64,
        sid_frames_tx: T,
        ctrl: Box<dyn ControlConcurrency>,
    ) -> Self {
        let peer_role = !role;
        let local_tx = sid_frames_tx.clone();
        // The local half is seeded with 0, 0. NOTE(review): presumably its
        // initial limits, raised later once the peer grants streams; confirm
        // in `local_sid`.
        Self {
            local: ArcLocalStreamIds::new(role, 0, 0, local_tx),
            remote: ArcRemoteStreamIds::new(
                peer_role,
                max_bi_streams,
                max_uni_streams,
                sid_frames_tx,
                ctrl,
            ),
        }
    }
}