use std::{
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use super::{Dir, Role, StreamId};
use crate::{
frame::{MaxStreamsFrame, ReceiveFrame, SendFrame, StreamsBlockedFrame},
sid::MAX_STREAMS_LIMIT,
varint::VarInt,
};
/// Allocator for stream IDs opened by the local endpoint.
///
/// Every two-element array below is indexed by `Dir as usize`: slot 0 holds the
/// bidirectional state and slot 1 the unidirectional state.
#[derive(Debug)]
struct LocalStreamIds<BLOCKED> {
    /// Role (client or server) of the local endpoint; part of every allocated ID.
    role: Role,
    /// Stream limit most recently granted by the peer, per direction.
    max: [u64; 2],
    /// Index of the next stream to hand out, per direction.
    unallocated: [u64; 2],
    /// Waker of the task blocked on the peer's limit, per direction.
    wakers: [Option<Waker>; 2],
    /// Sink used to emit STREAMS_BLOCKED frames when allocation stalls.
    blocked: BLOCKED,
}
impl<BLOCKED> LocalStreamIds<BLOCKED>
where
    BLOCKED: SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    /// Creates an allocator with the given initial per-direction stream limits.
    ///
    /// Allocation starts at stream index 0 in both directions. STREAMS_BLOCKED
    /// frames are sent through `blocked`.
    fn new(role: Role, max_bi_streams: u64, max_uni_streams: u64, blocked: BLOCKED) -> Self {
        Self {
            role,
            max: [max_bi_streams, max_uni_streams],
            unallocated: [0, 0],
            wakers: [None, None],
            blocked,
        }
    }

    /// Returns the role of the local endpoint.
    fn role(&self) -> Role {
        self.role
    }

    /// Applies a MAX_STREAMS frame received from the peer.
    ///
    /// Limits only grow. A value that is not larger than the current limit is
    /// ignored. When the limit grows, the task blocked on that direction (if any)
    /// is woken so it can retry allocation.
    ///
    /// # Panics
    ///
    /// Panics if the frame carries a value above `MAX_STREAMS_LIMIT`.
    fn recv_max_streams_frame(&mut self, frame: &MaxStreamsFrame) {
        let (dir, val) = match frame {
            MaxStreamsFrame::Bi(max) => (Dir::Bi, (*max).into_inner()),
            MaxStreamsFrame::Uni(max) => (Dir::Uni, (*max).into_inner()),
        };
        // NOTE(review): `val` is peer-controlled. RFC 9000 §19.11 requires a
        // FRAME_ENCODING_ERROR (not a panic) for values above 2^60; confirm that the
        // frame decoder rejects such values before they reach this point.
        assert!(val <= MAX_STREAMS_LIMIT);
        let idx = dir as usize;
        if val > self.max[idx] {
            self.max[idx] = val;
            if let Some(waker) = self.wakers[idx].take() {
                waker.wake();
            }
        }
    }

    /// Tries to allocate the next locally-initiated stream ID in direction `dir`.
    ///
    /// Per RFC 9000 §4.6, a stream limit of `n` permits stream indices `0..n`.
    ///
    /// Returns:
    /// - `Ready(Some(id))` when the next index is below the peer's limit;
    /// - `Ready(None)` when the index space up to `MAX_STREAMS_LIMIT` is used up,
    ///   so no future MAX_STREAMS frame can unblock allocation;
    /// - `Pending` otherwise. In that case the waker is registered and a
    ///   STREAMS_BLOCKED frame carrying the current limit is sent.
    fn poll_alloc_sid(&mut self, cx: &mut Context<'_>, dir: Dir) -> Poll<Option<StreamId>> {
        let idx = dir as usize;
        let next = self.unallocated[idx];
        // `next` never exceeds `max`, and `max` never exceeds MAX_STREAMS_LIMIT,
        // so reaching the limit means the ID space is exhausted for good.
        if next >= MAX_STREAMS_LIMIT {
            return Poll::Ready(None);
        }
        if next < self.max[idx] {
            self.unallocated[idx] = next + 1;
            return Poll::Ready(Some(StreamId::new(self.role, dir, next)));
        }
        // Blocked by the peer's limit. A task may legitimately be polled again
        // while still blocked (e.g. a spurious wake-up), so refresh the waker
        // instead of asserting the slot is empty. Only the latest waiter is kept.
        self.wakers[idx] = Some(cx.waker().clone());
        self.blocked.send_frame([StreamsBlockedFrame::with(
            dir,
            VarInt::from_u64(self.max[idx])
                .expect("max_streams limit must be less than VARINT_MAX"),
        )]);
        Poll::Pending
    }
}

/// Shared handle to the local stream-ID allocator.
///
/// Clones share the same underlying state through an `Arc<Mutex<_>>`.
#[derive(Debug, Clone)]
pub struct ArcLocalStreamIds<BLOCKED>(Arc<Mutex<LocalStreamIds<BLOCKED>>>);

impl<BLOCKED> ArcLocalStreamIds<BLOCKED>
where
    BLOCKED: SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    /// Creates an allocator for `role` with the given initial per-direction
    /// stream limits.
    ///
    /// STREAMS_BLOCKED frames are sent through `blocked`.
    pub fn new(role: Role, max_bi_streams: u64, max_uni_streams: u64, blocked: BLOCKED) -> Self {
        Self(Arc::new(Mutex::new(LocalStreamIds::new(
            role,
            max_bi_streams,
            max_uni_streams,
            blocked,
        ))))
    }

    /// Returns the role of the local endpoint.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn role(&self) -> Role {
        self.0.lock().unwrap().role()
    }

    /// Applies a MAX_STREAMS frame from the peer, waking a blocked allocator if
    /// the limit grows.
    ///
    /// # Panics
    ///
    /// Panics if the value exceeds `MAX_STREAMS_LIMIT` or the mutex is poisoned.
    pub fn recv_max_streams_frame(&self, frame: &MaxStreamsFrame) {
        self.0.lock().unwrap().recv_max_streams_frame(frame);
    }

    /// Polls for the next locally-initiated stream ID in direction `dir`.
    ///
    /// Returns `Ready(Some(id))` when the peer's limit allows another stream.
    /// Returns `Ready(None)` when the stream ID space is exhausted.
    /// Returns `Pending` (after sending STREAMS_BLOCKED) while blocked by the
    /// peer's limit.
    ///
    /// # Panics
    ///
    /// Panics if the inner mutex is poisoned.
    pub fn poll_alloc_sid(&self, cx: &mut Context<'_>, dir: Dir) -> Poll<Option<StreamId>> {
        self.0.lock().unwrap().poll_alloc_sid(cx, dir)
    }
}

impl<BLOCKED> ReceiveFrame<MaxStreamsFrame> for ArcLocalStreamIds<BLOCKED>
where
    BLOCKED: SendFrame<StreamsBlockedFrame> + Clone + Send + 'static,
{
    type Output = ();

    /// Applies the MAX_STREAMS frame.
    ///
    /// This always returns `Ok(())`; an out-of-range value panics inside
    /// `recv_max_streams_frame` instead of producing an error.
    fn recv_frame(&self, frame: &MaxStreamsFrame) -> Result<Self::Output, crate::error::Error> {
        self.recv_max_streams_frame(frame);
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use deref_derive::Deref;

    use super::*;
    use crate::util::ArcAsyncDeque;

    /// Test sink that queues every STREAMS_BLOCKED frame it is given.
    #[derive(Clone, Deref, Default)]
    struct StreamsBlockedFrameTx(ArcAsyncDeque<StreamsBlockedFrame>);

    impl SendFrame<StreamsBlockedFrame> for StreamsBlockedFrameTx {
        fn send_frame<I: IntoIterator<Item = StreamsBlockedFrame>>(&self, iter: I) {
            (&self.0).extend(iter);
        }
    }

    #[test]
    fn test_stream_id_new() {
        let sid = StreamId::new(Role::Client, Dir::Bi, 0);
        assert_eq!(sid, StreamId(0));
        assert_eq!(sid.role(), Role::Client);
        assert_eq!(sid.dir(), Dir::Bi);
    }

    #[test]
    fn test_recv_max_stream_frames() {
        let local = ArcLocalStreamIds::new(Role::Client, 0, 0, StreamsBlockedFrameTx::default());
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        // A limit of 0 permits no streams at all.
        local.recv_max_streams_frame(&MaxStreamsFrame::Bi(VarInt::from_u32(0)));
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending);
        assert!(local.0.lock().unwrap().wakers[0].is_some());

        // Raising the limit to 1 wakes the blocked task and permits exactly one stream.
        local.recv_max_streams_frame(&MaxStreamsFrame::Bi(VarInt::from_u32(1)));
        assert!(local.0.lock().unwrap().wakers[0].is_none());
        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Bi),
            Poll::Ready(Some(StreamId(0)))
        );
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending);

        local.recv_max_streams_frame(&MaxStreamsFrame::Bi(VarInt::from_u32(2)));
        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Bi),
            Poll::Ready(Some(StreamId(4)))
        );
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Pending);
        assert!(local.0.lock().unwrap().wakers[0].is_some());

        // Unidirectional limits are tracked independently of bidirectional ones.
        local.recv_max_streams_frame(&MaxStreamsFrame::Uni(VarInt::from_u32(3)));
        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Uni),
            Poll::Ready(Some(StreamId(2)))
        );
        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Uni),
            Poll::Ready(Some(StreamId(6)))
        );
        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Uni),
            Poll::Ready(Some(StreamId(10)))
        );
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Pending);
        assert!(local.0.lock().unwrap().wakers[1].is_some());
    }

    #[test]
    fn test_repoll_while_blocked() {
        let local = ArcLocalStreamIds::new(Role::Client, 0, 0, StreamsBlockedFrameTx::default());
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Pending);
        // Polling again while still blocked must not panic.
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Uni), Poll::Pending);

        local.recv_max_streams_frame(&MaxStreamsFrame::Uni(VarInt::from_u32(1)));
        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Uni),
            Poll::Ready(Some(StreamId(2)))
        );
    }

    #[test]
    fn test_stream_id_space_exhausted() {
        let local = ArcLocalStreamIds::new(
            Role::Client,
            MAX_STREAMS_LIMIT,
            0,
            StreamsBlockedFrameTx::default(),
        );
        // Jump straight to the last usable stream index.
        local.0.lock().unwrap().unallocated[0] = MAX_STREAMS_LIMIT - 1;
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert_eq!(
            local.poll_alloc_sid(&mut cx, Dir::Bi),
            Poll::Ready(Some(StreamId::new(
                Role::Client,
                Dir::Bi,
                MAX_STREAMS_LIMIT - 1
            )))
        );
        assert_eq!(local.poll_alloc_sid(&mut cx, Dir::Bi), Poll::Ready(None));
    }
}