use std::{
ops::{Deref, DerefMut},
sync::{
atomic::{AtomicBool, AtomicU64, Ordering},
Arc, Mutex, MutexGuard,
},
task::Waker,
};
use thiserror::Error;
use crate::{
error::Error as QuicError,
frame::{DataBlockedFrame, MaxDataFrame, ReceiveFrame, SendFrame},
varint::VarInt,
};
/// Sender-side connection-level flow-control state.
///
/// Tracks how many bytes have been accounted as sent against the current
/// limit, and keeps the wakers of tasks waiting for that limit to grow.
#[derive(Debug, Default)]
struct SendControler<TX> {
    /// Total bytes recorded as sent so far.
    sent_data: u64,
    /// Current send limit; only ever raised by `increase_limit`.
    max_data: u64,
    /// Frame sink; `Credit::post_sent` uses it to emit `DataBlockedFrame`s.
    block_tx: TX,
    /// Wakers of tasks waiting for more credit.
    wakers: Vec<Waker>,
}

impl<TX> SendControler<TX> {
    /// Creates a controller that allows `initial_max_data` bytes to be sent.
    fn new(initial_max_data: u64, block_tx: TX) -> Self {
        Self {
            sent_data: 0,
            max_data: initial_max_data,
            block_tx,
            wakers: Vec::with_capacity(4),
        }
    }

    /// Registers `waker` to be woken when the limit grows (or on error).
    ///
    /// A waker that would wake the same task as an already registered one is
    /// skipped. A task that repeatedly polls while starved of credit would
    /// otherwise push a new clone on every poll and grow the list without
    /// bound.
    fn register_waker(&mut self, waker: Waker) {
        if !self.wakers.iter().any(|registered| registered.will_wake(&waker)) {
            self.wakers.push(waker);
        }
    }

    /// Wakes and removes every registered waker.
    fn wake_all(&mut self) {
        self.wakers.drain(..).for_each(Waker::wake);
    }

    /// Raises the limit to `max_data` and wakes waiters.
    ///
    /// Values that do not exceed the current limit are ignored, so the limit
    /// never shrinks.
    fn increase_limit(&mut self, max_data: u64) {
        if max_data > self.max_data {
            self.max_data = max_data;
            self.wake_all();
        }
    }
}
/// Shared, lock-protected handle to the sender-side flow controller.
///
/// The inner `Result` becomes `Err` once a connection error is recorded via
/// [`ArcSendControler::on_error`]; from then on [`ArcSendControler::credit`]
/// returns a clone of that error.
#[derive(Clone, Debug)]
pub struct ArcSendControler<TX>(Arc<Mutex<Result<SendControler<TX>, QuicError>>>);

impl<TX> ArcSendControler<TX> {
    /// Creates a controller that allows `initial_max_data` bytes to be sent.
    /// `block_tx` is the sink through which blocked frames are emitted.
    pub fn new(initial_max_data: u64, block_tx: TX) -> Self {
        Self(Arc::new(Mutex::new(Ok(SendControler::new(
            initial_max_data,
            block_tx,
        )))))
    }

    /// Raises the send limit to `max_data` if it is larger than the current
    /// one. Does nothing once the controller has failed.
    fn increase_limit(&self, max_data: u64) {
        let mut guard = self.0.lock().unwrap();
        if let Ok(inner) = guard.deref_mut() {
            inner.increase_limit(max_data);
        }
    }

    /// Locks the controller and returns a [`Credit`] view of it.
    ///
    /// # Errors
    /// Returns a clone of the recorded error if [`ArcSendControler::on_error`]
    /// has been called.
    pub fn credit(&self) -> Result<Credit<'_, TX>, QuicError> {
        let guard = self.0.lock().unwrap();
        if let Err(e) = guard.deref() {
            return Err(e.clone());
        }
        Ok(Credit(guard))
    }

    /// Registers `waker` to be woken when more credit becomes available.
    ///
    /// If the controller has already failed, the waker is woken right away
    /// instead of being dropped. The task may have observed a healthy state
    /// just before the error was recorded; without this wake it would never
    /// be polled again to observe the error.
    pub fn register_waker(&self, waker: Waker) {
        let mut guard = self.0.lock().unwrap();
        if let Ok(inner) = guard.deref_mut() {
            inner.register_waker(waker);
            return;
        }
        drop(guard);
        waker.wake();
    }

    /// Records a connection error: wakes all waiters and switches the
    /// controller to the `Err` state. Only the first error is kept; later
    /// calls are ignored.
    pub fn on_error(&self, error: &QuicError) {
        let mut guard = self.0.lock().unwrap();
        if let Ok(inner) = guard.deref_mut() {
            inner.wake_all();
            *guard = Err(error.clone());
        }
    }
}
/// Applies a peer `MaxDataFrame` by raising the send limit.
///
/// Always returns `Ok(())`. A frame that does not raise the limit, or one that
/// arrives after the controller has failed, has no effect.
impl<TX> ReceiveFrame<MaxDataFrame> for ArcSendControler<TX> {
    type Output = ();

    fn recv_frame(&self, frame: &MaxDataFrame) -> Result<Self::Output, QuicError> {
        let new_limit = frame.max_data.into_inner();
        self.increase_limit(new_limit);
        Ok(())
    }
}
/// Exclusive view of the sender's flow-control credit.
///
/// Holds the controller's lock for its whole lifetime, so a caller can read
/// [`Credit::available`] and then record what it sent with `post_sent` without
/// another holder interleaving. `ArcSendControler::credit` only constructs it
/// after checking the controller is `Ok`, which is why the `Err` arms below are
/// unreachable.
pub struct Credit<'a, TX>(MutexGuard<'a, Result<SendControler<TX>, QuicError>>);

impl<TX> Credit<'_, TX> {
    /// Number of bytes that may still be sent under the current limit.
    ///
    /// Saturates instead of wrapping or truncating. It returns 0 if more was
    /// recorded as sent than the limit allows. It returns `usize::MAX` if the
    /// remaining credit does not fit in `usize`, which is possible on 32-bit
    /// targets, where an `as` cast would silently truncate.
    pub fn available(&self) -> usize {
        match self.0.deref() {
            Ok(inner) => {
                let remaining = inner.max_data.saturating_sub(inner.sent_data);
                usize::try_from(remaining).unwrap_or(usize::MAX)
            }
            Err(_) => unreachable!(),
        }
    }
}

impl<TX> Credit<'_, TX>
where
    TX: SendFrame<DataBlockedFrame>,
{
    /// Records that `amount` bytes were sent.
    ///
    /// Callers must not exceed [`Credit::available`]; this is only checked in
    /// debug builds. Once the limit is reached, a `DataBlockedFrame` carrying
    /// the current limit is sent through the controller's frame sink.
    ///
    /// # Panics
    /// In debug builds, if `amount` exceeds the remaining credit. Also if the
    /// limit does not fit in a `VarInt`.
    pub fn post_sent(&mut self, amount: usize) {
        match self.0.deref_mut() {
            Ok(inner) => {
                let amount = amount as u64;
                debug_assert!(inner.sent_data + amount <= inner.max_data);
                inner.sent_data += amount;
                // `>=` rather than `==`: if a release build overshoots the
                // limit, the peer must still be told we are blocked.
                if inner.sent_data >= inner.max_data {
                    inner.block_tx.send_frame([DataBlockedFrame {
                        limit: VarInt::from_u64(inner.max_data).expect(
                            "max_data of flow controller is very very hard to exceed 2^62 - 1",
                        ),
                    }]);
                }
            }
            Err(_) => unreachable!(),
        }
    }
}
#[derive(Debug, Clone, Copy, Error)]
#[error("Flow Control exceed {0} bytes on receiving")]
pub struct Overflow(usize);
/// Receiver-side connection-level flow-control state.
///
/// The counters are atomics, so the controller can be shared (through
/// `ArcRecvController`) without a lock.
#[derive(Debug, Default)]
struct RecvController<TX> {
    /// Total bytes received so far.
    rcvd_data: AtomicU64,
    /// Current receive limit.
    max_data: AtomicU64,
    /// Amount the limit is raised by. It is also the slack threshold: the
    /// limit is raised once no more than `step` bytes of it remain.
    step: u64,
    /// Set by `terminate`; receiving afterwards is a bug (debug-asserted).
    is_closed: AtomicBool,
    /// Frame sink used to advertise raised limits as `MaxDataFrame`s.
    max_data_tx: TX,
}

impl<TX> RecvController<TX> {
    /// Creates a controller with limit `initial_max_data`.
    ///
    /// The raise step is half of `initial_max_data`. NOTE(review): a step of 0
    /// (initial limit below 2) means the limit is never raised.
    fn new(initial_max_data: u64, max_data_tx: TX) -> Self {
        Self {
            rcvd_data: AtomicU64::new(0),
            max_data: AtomicU64::new(initial_max_data),
            step: initial_max_data / 2,
            is_closed: AtomicBool::new(false),
            max_data_tx,
        }
    }

    /// Marks the controller as closed. Calling it more than once is harmless.
    fn terminate(&self) {
        self.is_closed.store(true, Ordering::Release);
    }
}

impl<TX> RecvController<TX>
where
    TX: SendFrame<MaxDataFrame>,
{
    /// Accounts for `amount` newly received bytes.
    ///
    /// Returns `Ok(amount)` while the running total stays within the limit.
    /// When no more than `step` bytes of the limit remain, the limit is raised
    /// by `step` and the new value is sent as a `MaxDataFrame`.
    ///
    /// # Errors
    /// Returns [`Overflow`] with the number of bytes beyond the limit if the
    /// running total exceeds it. The bytes stay counted in the total.
    fn on_new_rcvd(&self, amount: usize) -> Result<usize, Overflow> {
        debug_assert!(!self.is_closed.load(Ordering::Relaxed));
        let amount_u64 = amount as u64;
        // Use fetch_add's return value so each call works with its own
        // snapshot of the total. A separate load lets concurrent callers see
        // the same total and each raise the limit for one threshold crossing.
        let rcvd_data = self.rcvd_data.fetch_add(amount_u64, Ordering::AcqRel) + amount_u64;
        let max_data = self.max_data.load(Ordering::Acquire);
        if rcvd_data > max_data {
            return Err(Overflow((rcvd_data - max_data) as usize));
        }
        // A zero step would only re-advertise the unchanged limit.
        if self.step > 0 && rcvd_data + self.step >= max_data {
            let new_max_data = max_data + self.step;
            // Only the caller that moves the limit away from the value it
            // observed advertises it, so concurrent callers raise it once and
            // the frame carries exactly the value that was stored.
            if self
                .max_data
                .compare_exchange(max_data, new_max_data, Ordering::AcqRel, Ordering::Acquire)
                .is_ok()
            {
                self.max_data_tx.send_frame([MaxDataFrame {
                    max_data: VarInt::from_u64(new_max_data)
                        .expect("max_data of flow controller is very very hard to exceed 2^62 - 1"),
                }]);
            }
        }
        Ok(amount)
    }
}
/// Shared handle to the receiver-side flow controller; clones share the same
/// underlying state.
#[derive(Debug, Default, Clone)]
pub struct ArcRecvController<TX>(Arc<RecvController<TX>>);

impl<TX> ArcRecvController<TX> {
    /// Creates a controller whose receive limit starts at `initial_max_data`.
    /// Raised limits are sent through `max_data_tx`.
    pub fn new(initial_max_data: u64, max_data_tx: TX) -> Self {
        let inner = RecvController::new(initial_max_data, max_data_tx);
        Self(Arc::new(inner))
    }

    /// Marks the controller as closed.
    pub fn terminate(&self) {
        let Self(inner) = self;
        inner.terminate();
    }
}

impl<TX> ArcRecvController<TX>
where
    TX: SendFrame<MaxDataFrame>,
{
    /// Accounts for `amount` newly received bytes.
    ///
    /// Returns `Err(Overflow)` if the running total exceeds the receive limit.
    pub fn on_new_rcvd(&self, amount: usize) -> Result<usize, Overflow> {
        let Self(inner) = self;
        inner.on_new_rcvd(amount)
    }
}
/// A peer `DataBlockedFrame` is accepted and ignored.
impl<TX> ReceiveFrame<DataBlockedFrame> for ArcRecvController<TX> {
    type Output = ();

    fn recv_frame(&self, _: &DataBlockedFrame) -> Result<Self::Output, QuicError> {
        // Nothing to do here: `on_new_rcvd` raises the limit proactively as
        // data arrives.
        Ok(())
    }
}
/// Connection-level flow control. It pairs a sending half, bounded by the
/// peer's limit, with a receiving half, bounded by our own limit.
#[derive(Debug, Clone)]
pub struct FlowController<TX> {
    pub sender: ArcSendControler<TX>,
    pub recver: ArcRecvController<TX>,
}

impl<TX: Clone> FlowController<TX> {
    /// Builds both halves.
    ///
    /// `peer_initial_max_data` is the initial send limit and
    /// `local_initial_max_data` the initial receive limit. The sender gets a
    /// clone of `frames_tx` and the receiver gets the original.
    pub fn new(peer_initial_max_data: u64, local_initial_max_data: u64, frames_tx: TX) -> Self {
        let sender = ArcSendControler::new(peer_initial_max_data, frames_tx.clone());
        let recver = ArcRecvController::new(local_initial_max_data, frames_tx);
        Self { sender, recver }
    }

    /// Raises the send limit to `snd_wnd`. Values not above the current limit
    /// are ignored.
    pub fn reset_send_window(&self, snd_wnd: u64) {
        self.sender.increase_limit(snd_wnd);
    }

    /// Locks the sending half and returns its credit, or the recorded
    /// connection error.
    pub fn send_limit(&self) -> Result<Credit<'_, TX>, QuicError> {
        self.sender.credit()
    }

    /// Propagates a connection error. The sending half records it and wakes
    /// its waiters; the receiving half is marked closed.
    pub fn on_conn_error(&self, error: &QuicError) {
        let Self { sender, recver } = self;
        sender.on_error(error);
        recver.terminate();
    }
}

impl<TX> FlowController<TX>
where
    TX: SendFrame<MaxDataFrame>,
{
    /// Accounts for `amount` newly received bytes on the receiving half.
    pub fn on_new_rcvd(&self, amount: usize) -> Result<usize, Overflow> {
        let Self { recver, .. } = self;
        recver.on_new_rcvd(amount)
    }
}