use std::sync::{Arc, Mutex};
use thiserror::Error;
use super::{ControlConcurrency, Dir, Role, StreamId};
use crate::{
frame::{MaxStreamsFrame, ReceiveFrame, SendFrame, StreamsBlockedFrame},
varint::VarInt,
};
#[derive(Debug, PartialEq, Error)]
#[error("{0} exceed limit: {1}")]
pub struct ExceedLimitError(StreamId, u64);
/// Result of successfully accepting a peer-initiated stream ID.
#[derive(PartialEq, Debug)]
pub enum AcceptSid {
    /// The ID is below the next unallocated ID of its direction, so it was
    /// already covered by an earlier accept.
    Old,
    /// The ID is new; the contained range lists every stream ID (up to and
    /// including the accepted one) that must now be created.
    New(NeedCreate),
}
/// Inclusive range `start..=end` of stream IDs of one direction that became
/// valid in a single accept and still need to be created.
///
/// Iterating the range yields each ID in ascending order.
#[derive(PartialEq, Debug)]
pub struct NeedCreate {
    /// First ID to create (the previously next-unallocated ID).
    start: StreamId,
    /// Last ID to create, inclusive (the ID that was just accepted).
    end: StreamId,
}
impl Iterator for NeedCreate {
    type Item = StreamId;

    /// Yields the current `start` and advances it to the following stream ID,
    /// stopping once `start` has moved past `end`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.start > self.end {
            return None;
        }
        let current = self.start;
        // SAFETY(review): assumed sound because `current <= end` and `end` is a
        // stream ID that was already accepted — confirm against the contract of
        // `StreamId::next_unchecked`.
        self.start = unsafe { current.next_unchecked() };
        Some(current)
    }
}
/// Bookkeeping for stream IDs opened by the peer.
#[derive(Debug)]
struct RemoteStreamIds<MAX> {
    /// Role of the initiator of the tracked stream IDs; IDs of any other role
    /// are ignored by `on_end_of_stream`.
    role: Role,
    /// Current stream limit per direction, indexed by `Dir as usize`
    /// (bidirectional first, unidirectional second).
    max: [u64; 2],
    /// Next stream ID not yet accepted, per direction (same indexing as `max`).
    unallocated: [StreamId; 2],
    /// Policy consulted to decide whether (and to what) a limit is raised.
    ctrl: Box<dyn ControlConcurrency>,
    /// Sink that receives the MAX_STREAMS frames emitted by this type.
    max_tx: MAX,
}
impl<MAX> RemoteStreamIds<MAX>
where
MAX: SendFrame<MaxStreamsFrame> + Clone + Send + 'static,
{
fn new(
role: Role,
max_bi: u64,
max_uni: u64,
max_tx: MAX,
ctrl: Box<dyn ControlConcurrency>,
) -> Self {
Self {
role,
max: [max_bi, max_uni],
unallocated: [
StreamId::new(role, Dir::Bi, 0),
StreamId::new(role, Dir::Uni, 0),
],
ctrl,
max_tx,
}
}
fn role(&self) -> Role {
self.role
}
fn try_accept_sid(&mut self, sid: StreamId) -> Result<AcceptSid, ExceedLimitError> {
debug_assert_eq!(sid.role(), self.role);
let idx = sid.dir() as usize;
if sid.id() > self.max[idx] {
return Err(ExceedLimitError(sid, self.max[idx]));
}
let cur = &mut self.unallocated[idx];
if sid < *cur {
Ok(AcceptSid::Old)
} else {
let start = *cur;
*cur = unsafe { sid.next_unchecked() };
log::debug!("unallocated: {:?}", self.unallocated[idx]);
if let Some(max_streams) = self.ctrl.on_accept_streams(sid.dir(), sid.id()) {
self.max[idx] = max_streams;
self.max_tx.send_frame([MaxStreamsFrame::with(
sid.dir(),
VarInt::from_u64(max_streams)
.expect("max_streams must be less than VARINT_MAX"),
)]);
}
Ok(AcceptSid::New(NeedCreate { start, end: sid }))
}
}
fn on_end_of_stream(&mut self, sid: StreamId) {
if sid.role() != self.role {
return;
}
if let Some(max_streams) = self.ctrl.on_end_of_stream(sid.dir(), sid.id()) {
self.max[sid.dir() as usize] = max_streams;
self.max_tx.send_frame([MaxStreamsFrame::with(
sid.dir(),
VarInt::from_u64(max_streams).expect("max_streams must be less than VARINT_MAX"),
)]);
}
}
fn recv_streams_blocked_frame(&mut self, frame: &StreamsBlockedFrame) {
let (dir, max_streams) = match frame {
StreamsBlockedFrame::Bi(max) => (Dir::Bi, (*max).into_inner()),
StreamsBlockedFrame::Uni(max) => (Dir::Uni, (*max).into_inner()),
};
if let Some(max_streams) = self.ctrl.on_streams_blocked(dir, max_streams) {
self.max[dir as usize] = max_streams;
self.max_tx.send_frame([MaxStreamsFrame::with(
dir,
VarInt::from_u64(max_streams).expect("max_streams must be less than VARINT_MAX"),
)]);
}
}
}
/// Shared, mutex-protected handle to the peer-initiated stream ID state.
///
/// Cloning clones the inner `Arc`, so all clones operate on the same state.
#[derive(Clone, Debug)]
pub struct ArcRemoteStreamIds<MAX>(Arc<Mutex<RemoteStreamIds<MAX>>>);
impl<MAX> ArcRemoteStreamIds<MAX>
where
    MAX: SendFrame<MaxStreamsFrame> + Clone + Send + 'static,
{
    /// Creates a shared handle for stream IDs initiated by `role`, with the
    /// given initial bidirectional/unidirectional limits, MAX_STREAMS sink and
    /// concurrency policy.
    pub fn new(
        role: Role,
        max_bi: u64,
        max_uni: u64,
        max_tx: MAX,
        ctrl: Box<dyn ControlConcurrency>,
    ) -> Self {
        let state = RemoteStreamIds::new(role, max_bi, max_uni, max_tx, ctrl);
        Self(Arc::new(Mutex::new(state)))
    }

    /// Returns the initiator role of the tracked stream IDs.
    ///
    /// # Panics
    /// If the inner mutex is poisoned.
    pub fn role(&self) -> Role {
        let guard = self.0.lock().unwrap();
        guard.role()
    }

    /// Tries to accept a peer-initiated stream ID; see [`AcceptSid`] for the
    /// possible outcomes.
    ///
    /// # Errors
    /// [`ExceedLimitError`] when the ID exceeds the current limit.
    ///
    /// # Panics
    /// If the inner mutex is poisoned.
    pub fn try_accept_sid(&self, sid: StreamId) -> Result<AcceptSid, ExceedLimitError> {
        let mut guard = self.0.lock().unwrap();
        guard.try_accept_sid(sid)
    }

    /// Reports that a stream has ended, which may raise the stream limit.
    ///
    /// # Panics
    /// If the inner mutex is poisoned.
    #[inline]
    pub fn on_end_of_stream(&self, sid: StreamId) {
        let mut guard = self.0.lock().unwrap();
        guard.on_end_of_stream(sid);
    }

    /// Handles a STREAMS_BLOCKED frame from the peer, which may raise the
    /// stream limit.
    ///
    /// # Panics
    /// If the inner mutex is poisoned.
    #[inline]
    pub fn recv_streams_blocked_frame(&self, frame: &StreamsBlockedFrame) {
        let mut guard = self.0.lock().unwrap();
        guard.recv_streams_blocked_frame(frame);
    }
}
impl<MAX> ReceiveFrame<StreamsBlockedFrame> for ArcRemoteStreamIds<MAX>
where
    MAX: SendFrame<MaxStreamsFrame> + Clone + Send + 'static,
{
    type Output = ();

    /// Routes a STREAMS_BLOCKED frame into the shared state. This path never
    /// returns an error; it panics only if the inner mutex is poisoned.
    fn recv_frame(&self, frame: &StreamsBlockedFrame) -> Result<Self::Output, crate::error::Error> {
        self.0.lock().unwrap().recv_streams_blocked_frame(frame);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use deref_derive::Deref;

    use super::*;
    use crate::{sid::handy::ConsistentConcurrency, util::ArcAsyncDeque};

    /// Test sink that records every MAX_STREAMS frame it is asked to send.
    #[derive(Clone, Deref, Default)]
    struct MaxStreamsFrameTx(ArcAsyncDeque<MaxStreamsFrame>);

    impl SendFrame<MaxStreamsFrame> for MaxStreamsFrameTx {
        fn send_frame<I: IntoIterator<Item = MaxStreamsFrame>>(&self, iter: I) {
            (&self.0).extend(iter);
        }
    }

    /// Expected outcome for a fresh accept covering `start..=end`.
    fn new_range(start: StreamId, end: StreamId) -> Result<AcceptSid, ExceedLimitError> {
        Ok(AcceptSid::New(NeedCreate { start, end }))
    }

    #[test]
    fn test_try_accept_sid() {
        let remote = ArcRemoteStreamIds::new(
            Role::Server,
            10,
            5,
            MaxStreamsFrameTx::default(),
            Box::new(ConsistentConcurrency::new(10, 5)),
        );
        // Next unallocated bidirectional stream ID.
        let next_bi = || remote.0.lock().unwrap().unallocated[0];

        let first = remote.try_accept_sid(StreamId(21));
        assert_eq!(first, new_range(StreamId(1), StreamId(21)));
        assert_eq!(next_bi(), StreamId(25));

        let second = remote.try_accept_sid(StreamId(25));
        assert_eq!(second, new_range(StreamId(25), StreamId(25)));
        assert_eq!(next_bi(), StreamId(29));

        let third = remote.try_accept_sid(StreamId(41));
        assert_eq!(third, new_range(StreamId(29), StreamId(41)));
        assert_eq!(next_bi(), StreamId(45));

        if let Ok(AcceptSid::New(range)) = third {
            let created: Vec<StreamId> = range.collect();
            assert_eq!(
                created,
                vec![StreamId(29), StreamId(33), StreamId(37), StreamId(41)]
            );
        }

        let rejected = remote.try_accept_sid(StreamId(65));
        assert_eq!(rejected, Err(ExceedLimitError(StreamId(65), 10)));
    }
}