use std::{
collections::VecDeque,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use super::{Dir, Role, StreamId};
use crate::{
frame::{
MaxStreamsFrame, StreamsBlockedFrame,
io::{ReceiveFrame, SendFrame},
},
net::tx::{ArcSendWakers, Signals},
sid::MAX_STREAMS_LIMIT,
varint::VarInt,
};
/// Allocator for locally-initiated stream IDs.
///
/// Tracks, per stream direction, how many streams the peer allows us to open
/// (`max`) and how many we have already opened (`unallocated`). All array
/// fields are indexed by `Dir as usize`.
#[derive(Debug)]
struct LocalStreamIds<BLOCKED> {
    /// Our endpoint role; encoded into every stream ID we allocate.
    role: Role,
    /// Current stream-count limit per direction.
    max: [u64; 2],
    /// Number of streams allocated so far per direction, which is also the
    /// index of the next stream ID to hand out.
    unallocated: [u64; 2],
    /// Tasks parked in `poll_alloc_sid` waiting for the limit to be raised.
    wakers: [VecDeque<Waker>; 2],
    /// Frame sink used to emit STREAMS_BLOCKED when allocation is blocked.
    blocked: BLOCKED,
    /// Send-path wakers, signalled with `Signals::WRITTEN` when a limit
    /// increase covers streams that were allocated beyond the previous limit.
    tx_wakers: ArcSendWakers,
}

impl<BLOCKED> LocalStreamIds<BLOCKED>
where
    BLOCKED: SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    /// Creates an allocator with the given initial per-direction limits.
    ///
    /// Only a client may start with non-zero limits (e.g. remembered
    /// parameters); a server must pass zero for both, which is checked in
    /// debug builds.
    fn new(
        role: Role,
        init_max_bi_streams: u64,
        init_max_uni_streams: u64,
        blocked: BLOCKED,
        tx_wakers: ArcSendWakers,
    ) -> Self {
        debug_assert!(
            role == Role::Client || (init_max_bi_streams == 0 && init_max_uni_streams == 0),
            "Server cannot remember the parameters"
        );
        Self {
            role,
            max: [init_max_bi_streams, init_max_uni_streams],
            unallocated: [0, 0],
            wakers: [VecDeque::with_capacity(2), VecDeque::with_capacity(2)],
            blocked,
            tx_wakers,
        }
    }

    /// Returns our endpoint role.
    fn role(&self) -> Role {
        self.role
    }

    /// Returns how many streams in direction `dir` have been allocated so far.
    fn opened_streams(&self, dir: Dir) -> u64 {
        self.unallocated[dir as usize]
    }

    /// Applies a MAX_STREAMS frame received from the peer.
    ///
    /// # Panics
    /// Panics if the advertised limit exceeds `MAX_STREAMS_LIMIT`.
    fn recv_max_streams_frame(&mut self, frame: &MaxStreamsFrame) {
        let (dir, val) = match frame {
            MaxStreamsFrame::Bi(max) => (Dir::Bi, (*max).into_inner()),
            MaxStreamsFrame::Uni(max) => (Dir::Uni, (*max).into_inner()),
        };
        self.increase_limit(dir, val);
    }

    /// Raises the stream limit for `dir` to `val`; values not greater than
    /// the current limit are ignored.
    ///
    /// On an actual increase, every task parked in `poll_alloc_sid` for this
    /// direction is woken.
    ///
    /// # Panics
    /// Panics if `val > MAX_STREAMS_LIMIT`.
    fn increase_limit(&mut self, dir: Dir, val: u64) {
        // NOTE(review): `val` may come straight from a peer frame; presumably
        // the frame decoder already rejects values above the limit — confirm,
        // otherwise a peer can trigger this panic.
        assert!(val <= MAX_STREAMS_LIMIT);
        let idx = dir as usize;
        let max_streams = &mut self.max[idx];
        if *max_streams < val {
            // Some streams were allocated above the old limit (possible after
            // `revise_max_streams` reset the limits to zero); presumably the
            // send path holds them back, so wake it to re-check.
            if *max_streams < self.unallocated[idx] {
                self.tx_wakers.wake_all_by(Signals::WRITTEN);
            }
            for waker in self.wakers[idx].drain(..) {
                waker.wake();
            }
            *max_streams = val;
        }
    }

    /// Tries to allocate the next stream ID in direction `dir`.
    ///
    /// Returns:
    /// - `Ready(Some(id))` when the current limit allows another stream;
    /// - `Ready(None)` once all `MAX_STREAMS_LIMIT` stream indices have been
    ///   used, since no limit increase can ever unblock further allocation;
    /// - `Pending` otherwise, after registering `cx`'s waker (once per task)
    ///   and emitting a STREAMS_BLOCKED frame carrying the current limit.
    fn poll_alloc_sid(&mut self, cx: &mut Context<'_>, dir: Dir) -> Poll<Option<StreamId>> {
        let idx = dir as usize;
        let max = self.max[idx];
        let unallocated = self.unallocated[idx];
        // Valid stream indices are `0..MAX_STREAMS_LIMIT`. `unallocated` only
        // grows while below `max`, and `max` is capped at MAX_STREAMS_LIMIT,
        // so exhaustion shows up as equality. It must be reported here rather
        // than parking the caller forever.
        if unallocated >= MAX_STREAMS_LIMIT {
            Poll::Ready(None)
        } else if unallocated < max {
            self.unallocated[idx] += 1;
            Poll::Ready(Some(StreamId::new(self.role, dir, unallocated)))
        } else {
            // Do not queue the same task again on repeated polls; otherwise
            // the queue grows without bound until the limit is raised.
            let waker = cx.waker();
            let queue = &mut self.wakers[idx];
            if !queue.iter().any(|queued| queued.will_wake(waker)) {
                queue.push_back(waker.clone());
            }
            self.blocked.send_frame([StreamsBlockedFrame::with(
                dir,
                VarInt::from_u64(max).expect("max_streams limit must be less than VARINT_MAX"),
            )]);
            Poll::Pending
        }
    }

    /// Re-applies the stream limits `max_stream_bidi` / `max_stream_uni`.
    ///
    /// If `zero_rtt_rejected` is set, the current limits are first reset to
    /// zero, so the new values are not masked by larger previously-assumed
    /// ones. The new values are then applied via `increase_limit`, which
    /// wakes waiting tasks and panics on out-of-range values.
    pub fn revise_max_streams(
        &mut self,
        zero_rtt_rejected: bool,
        max_stream_bidi: u64,
        max_stream_uni: u64,
    ) {
        if zero_rtt_rejected {
            self.max = [0, 0];
        }
        self.increase_limit(Dir::Bi, max_stream_bidi);
        self.increase_limit(Dir::Uni, max_stream_uni);
    }
}
/// Shared, thread-safe handle to a `LocalStreamIds` allocator.
///
/// Clones share the same allocator. Every method locks the inner mutex for
/// the duration of exactly one call and panics if the mutex is poisoned.
#[derive(Debug, Clone)]
pub struct ArcLocalStreamIds<BLOCKED>(Arc<Mutex<LocalStreamIds<BLOCKED>>>);

impl<BLOCKED> ArcLocalStreamIds<BLOCKED>
where
    BLOCKED: SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    /// Creates a handle for `role` with initial bidirectional (`max_bidi`)
    /// and unidirectional (`max_uni`) stream limits.
    pub fn new(
        role: Role,
        max_bidi: u64,
        max_uni: u64,
        blocked: BLOCKED,
        tx_wakers: ArcSendWakers,
    ) -> Self {
        let allocator = LocalStreamIds::new(role, max_bidi, max_uni, blocked, tx_wakers);
        Self(Arc::new(Mutex::new(allocator)))
    }

    /// Runs `op` against the locked allocator and returns its result; the
    /// lock is released as soon as `op` returns.
    fn locked<R>(&self, op: impl FnOnce(&mut LocalStreamIds<BLOCKED>) -> R) -> R {
        let mut guard = self.0.lock().unwrap();
        op(&mut *guard)
    }

    /// Returns our endpoint role.
    pub fn role(&self) -> Role {
        self.locked(|ids| ids.role())
    }

    /// Returns how many streams in direction `dir` have been allocated.
    pub fn opened_streams(&self, dir: Dir) -> u64 {
        self.locked(|ids| ids.opened_streams(dir))
    }

    /// Applies a MAX_STREAMS frame received from the peer.
    pub fn recv_max_streams_frame(&self, frame: &MaxStreamsFrame) {
        self.locked(|ids| ids.recv_max_streams_frame(frame))
    }

    /// Tries to allocate the next stream ID in direction `dir`, registering
    /// `cx`'s waker if the current limit is exhausted.
    pub fn poll_alloc_sid(&self, cx: &mut Context<'_>, dir: Dir) -> Poll<Option<StreamId>> {
        self.locked(|ids| ids.poll_alloc_sid(cx, dir))
    }

    /// Re-applies stream limits, first resetting them when 0-RTT was rejected.
    pub fn revise_max_streams(
        &self,
        zero_rtt_rejected: bool,
        max_stream_bidi: u64,
        max_stream_uni: u64,
    ) {
        self.locked(|ids| {
            ids.revise_max_streams(zero_rtt_rejected, max_stream_bidi, max_stream_uni)
        })
    }
}

impl<BLOCKED> ReceiveFrame<MaxStreamsFrame> for ArcLocalStreamIds<BLOCKED>
where
    BLOCKED: SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    type Output = ();

    /// Feeds a MAX_STREAMS frame into the allocator. This path never returns
    /// an error itself.
    fn recv_frame(&self, frame: &MaxStreamsFrame) -> Result<Self::Output, crate::error::Error> {
        self.locked(|ids| ids.recv_max_streams_frame(frame));
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use derive_more::Deref;

    use super::*;
    use crate::util::ArcAsyncDeque;

    /// Test sink that collects every emitted STREAMS_BLOCKED frame.
    #[derive(Clone, Deref, Default)]
    struct StreamsBlockedFrameTx(ArcAsyncDeque<StreamsBlockedFrame>);

    impl SendFrame<StreamsBlockedFrame> for StreamsBlockedFrameTx {
        fn send_frame<I: IntoIterator<Item = StreamsBlockedFrame>>(&self, iter: I) {
            (&self.0).extend(iter);
        }
    }

    #[test]
    fn test_stream_id_new() {
        let sid = StreamId::new(Role::Client, Dir::Bi, 0);
        assert_eq!(sid, StreamId(0));
        assert_eq!((sid.role(), sid.dir()), (Role::Client, Dir::Bi));
    }

    #[test]
    fn test_recv_max_stream_frames() {
        let local = ArcLocalStreamIds::new(
            Role::Client,
            0,
            0,
            StreamsBlockedFrameTx::default(),
            ArcSendWakers::default(),
        );
        let mut cx = Context::from_waker(futures::task::noop_waker_ref());
        let has_waiters = |dir: Dir| !local.0.lock().unwrap().wakers[dir as usize].is_empty();

        // A zero limit keeps bidirectional allocation blocked.
        local.recv_max_streams_frame(&MaxStreamsFrame::Bi(VarInt::from_u32(0)));
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending);
        assert!(has_waiters(Dir::Bi));

        // Raising the limit to one admits exactly one more stream.
        local.recv_max_streams_frame(&MaxStreamsFrame::Bi(VarInt::from_u32(1)));
        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Bi),
            Poll::Ready(Some(StreamId(0)))
        );
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending);
        assert!(has_waiters(Dir::Bi));

        // Client-initiated unidirectional stream IDs are 2, 6, ...
        local.recv_max_streams_frame(&MaxStreamsFrame::Uni(VarInt::from_u32(2)));
        for expected in [2, 6] {
            assert_eq!(
                local.poll_alloc_sid(&mut cx, Dir::Uni),
                Poll::Ready(Some(StreamId(expected)))
            );
        }
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Pending);
        assert!(has_waiters(Dir::Uni));
    }
}