use std::{
ops::{Deref, DerefMut},
sync::{Arc, Mutex},
};
use crate::{
error::{Error, ErrorFrameType, ErrorKind, QuicError},
frame::{DataBlockedFrame, FrameType, MaxDataFrame, ReceiveFrame, SendFrame},
net::tx::{ArcSendWakers, Signals},
varint::VarInt,
};
/// Connection-level send-side flow control state.
///
/// Tracks how many bytes have been committed for sending (`sent_data`)
/// against the current send limit (`max_data`).
#[derive(Debug)]
struct SendControler<TX> {
    /// Total bytes committed for sending, net of credit returned unused.
    sent_data: u64,
    /// Current connection-level send limit.
    max_data: u64,
    /// Whether a DATA_BLOCKED frame has already been queued for the current
    /// `max_data`. Cleared when the limit is raised or reset, so at most one
    /// DATA_BLOCKED frame is queued per limit value.
    flow_limited: bool,
    /// Sink used to queue DATA_BLOCKED frames.
    broker: TX,
    /// Wakers notified with `Signals::FLOW_CONTROL` when send credit may be available.
    tx_wakers: ArcSendWakers,
}

impl<TX> SendControler<TX> {
    /// Creates a controller with nothing sent and a send limit of `initial_max_data`.
    fn new(initial_max_data: u64, broker: TX, tx_wakers: ArcSendWakers) -> Self {
        Self {
            sent_data: 0,
            max_data: initial_max_data,
            flow_limited: false,
            broker,
            tx_wakers,
        }
    }

    /// Raises the send limit to `max_data` if it exceeds the current one.
    ///
    /// On a raise, this clears the blocked flag and wakes senders waiting on
    /// flow control. Values that do not raise the limit are ignored.
    fn increase_limit(&mut self, max_data: u64) {
        if max_data > self.max_data {
            self.max_data = max_data;
            self.flow_limited = false;
            self.tx_wakers.wake_all_by(Signals::FLOW_CONTROL);
        }
    }

    /// Bytes that can still be committed before reaching `max_data`.
    ///
    /// The result saturates at zero. `revise_max_data` can lower the limit
    /// (when 0-RTT was rejected) without touching `sent_data`, so `sent_data`
    /// may exceed `max_data`, and a plain subtraction would underflow.
    fn avaliable(&self) -> u64 {
        self.max_data.saturating_sub(self.sent_data)
    }

    /// Commits `flow` bytes as sent.
    ///
    /// If that exhausts the limit, queues a DATA_BLOCKED frame carrying the
    /// current limit. This happens only once per limit value.
    fn commit(&mut self, flow: u64)
    where
        TX: SendFrame<DataBlockedFrame>,
    {
        self.sent_data += flow;
        if self.avaliable() == 0 && !self.flow_limited {
            self.flow_limited = true;
            self.broker.send_frame([DataBlockedFrame::new(
                VarInt::from_u64(self.max_data)
                    .expect("max_data of flow controller is very very hard to exceed 2^62 - 1"),
            )]);
        }
    }

    /// Gives back `flow` previously committed but unused bytes.
    ///
    /// Wakes senders if any credit is available afterwards.
    fn return_back(&mut self, flow: u64) {
        self.sent_data -= flow;
        if self.avaliable() > 0 {
            self.tx_wakers.wake_all_by(Signals::FLOW_CONTROL);
        }
    }

    /// Re-applies the send limit `max_data`.
    ///
    /// If `zero_rtt_rejected` is true, the current limit is first discarded:
    /// it is reset to 0 and the blocked flag is cleared. Otherwise the limit
    /// is only ever raised. `sent_data` is left unchanged in both cases, which
    /// is why `avaliable` must saturate.
    fn revise_max_data(&mut self, zero_rtt_rejected: bool, max_data: u64) {
        if zero_rtt_rejected {
            self.max_data = 0;
            self.flow_limited = false;
        }
        self.increase_limit(max_data);
    }
}
/// Shared, thread-safe handle to a `SendControler`.
///
/// Once `on_error` has been called, the inner state is replaced by that
/// error. From then on, limit updates are ignored and `credit` returns a
/// clone of the stored error.
#[derive(Clone, Debug)]
pub struct ArcSendControler<TX>(Arc<Mutex<Result<SendControler<TX>, Error>>>);

impl<TX> ArcSendControler<TX> {
    /// Creates a handle whose send limit starts at `initial_max_data`.
    pub fn new(initial_max_data: u64, broker: TX, tx_wakers: ArcSendWakers) -> Self {
        let controller = SendControler::new(initial_max_data, broker, tx_wakers);
        Self(Arc::new(Mutex::new(Ok(controller))))
    }

    /// Raises the send limit to `max_data`. Ignored after an error.
    fn increase_limit(&self, max_data: u64) {
        if let Ok(controller) = self.0.lock().unwrap().deref_mut() {
            controller.increase_limit(max_data);
        }
    }

    /// Reserves up to `quota` bytes of send credit.
    ///
    /// The granted amount, which may be 0, is committed immediately. Unused
    /// credit is given back when the returned [`Credit`] is dropped.
    ///
    /// # Errors
    ///
    /// Returns a clone of the stored error if `on_error` was called.
    pub fn credit(&self, quota: usize) -> Result<Credit<'_, TX>, Error>
    where
        TX: SendFrame<DataBlockedFrame>,
    {
        let mut state = self.0.lock().unwrap();
        let controller = state.as_mut().map_err(|e| e.clone())?;
        let granted = controller.avaliable().min(quota as u64);
        controller.commit(granted);
        Ok(Credit {
            available: granted as usize,
            controller: self,
        })
    }

    /// Re-applies the send limit `max_data`. Ignored after an error.
    ///
    /// If `zero_rtt_rejected` is true, the previous limit is discarded before
    /// `max_data` is applied.
    pub fn revise_max_data(&self, zero_rtt_rejected: bool, max_data: u64) {
        let mut state = self.0.lock().unwrap();
        if let Ok(controller) = state.deref_mut() {
            controller.revise_max_data(zero_rtt_rejected, max_data);
        }
    }

    /// Moves the controller into the failed state with `error`.
    ///
    /// Only the first error is kept; later calls are no-ops.
    pub fn on_error(&self, error: &Error) {
        let mut state = self.0.lock().unwrap();
        if state.deref().is_ok() {
            *state = Err(error.clone());
        }
    }
}
/// Applies a peer MAX_DATA frame by raising the connection send limit.
///
/// Frames that do not raise the limit are ignored. This never returns an error.
impl<TX> ReceiveFrame<MaxDataFrame> for ArcSendControler<TX> {
    type Output = ();

    fn recv_frame(&self, frame: &MaxDataFrame) -> Result<Self::Output, Error> {
        let new_limit = frame.max_data();
        self.increase_limit(new_limit);
        Ok(())
    }
}
pub struct Credit<'a, TX> {
available: usize,
controller: &'a ArcSendControler<TX>,
}
impl<TX> Credit<'_, TX> {
pub fn available(&self) -> usize {
self.available
}
}
impl<TX> Credit<'_, TX>
where
TX: SendFrame<DataBlockedFrame>,
{
pub fn post_sent(&mut self, amount: usize) {
self.available -= amount;
}
}
impl<TX> Drop for Credit<'_, TX> {
fn drop(&mut self) {
if let Ok(inner) = self.controller.0.lock().unwrap().as_mut() {
inner.return_back(self.available as u64);
}
}
}
/// Connection-level receive-side flow control state.
#[derive(Debug, Default)]
struct RecvController<TX> {
    /// Total bytes reported through `on_new_rcvd`.
    rcvd_data: u64,
    /// Current receive limit. Raised values are sent out in MAX_DATA frames.
    max_data: u64,
    /// Amount by which `max_data` grows each time the window is extended.
    /// It is half of the initial limit.
    step: u64,
    /// Sink used to queue MAX_DATA frames.
    broker: TX,
}

impl<TX> RecvController<TX> {
    /// Creates a controller accepting up to `initial_max_data` bytes.
    fn new(initial_max_data: u64, broker: TX) -> Self {
        Self {
            rcvd_data: 0,
            max_data: initial_max_data,
            step: initial_max_data / 2,
            broker,
        }
    }
}

impl<TX> RecvController<TX>
where
    TX: SendFrame<MaxDataFrame>,
{
    /// Accounts `amount` newly received bytes carried by a frame of `frame_type`.
    ///
    /// When at most `step` bytes of the window remain, the limit grows by
    /// `step` and a MAX_DATA frame with the new limit is queued.
    ///
    /// # Errors
    ///
    /// Returns a FLOW_CONTROL error if the total received exceeds the limit.
    fn on_new_rcvd(&mut self, frame_type: FrameType, amount: usize) -> Result<usize, Error> {
        self.rcvd_data += amount as u64;
        if self.rcvd_data > self.max_data {
            return Err(QuicError::new(
                ErrorKind::FlowControl,
                ErrorFrameType::V1(frame_type),
                format!("flow control overflow: {}", self.rcvd_data - self.max_data),
            )
            .into());
        }
        // A zero step (initial limit 0 or 1) can never raise the limit.
        // Skip it so no MAX_DATA frame repeating the current limit is queued.
        if self.step > 0 && self.rcvd_data + self.step >= self.max_data {
            self.max_data += self.step;
            self.broker.send_frame([MaxDataFrame::new(
                VarInt::from_u64(self.max_data)
                    .expect("max_data of flow controller is very very hard to exceed 2^62 - 1"),
            )]);
        }
        Ok(amount)
    }
}
/// Shared, thread-safe handle to a `RecvController`.
#[derive(Debug, Default, Clone)]
pub struct ArcRecvController<TX>(Arc<Mutex<RecvController<TX>>>);

impl<TX> ArcRecvController<TX> {
    /// Creates a handle that initially accepts up to `initial_max_data` bytes.
    pub fn new(initial_max_data: u64, broker: TX) -> Self {
        let controller = RecvController::new(initial_max_data, broker);
        Self(Arc::new(Mutex::new(controller)))
    }
}

impl<TX> ArcRecvController<TX>
where
    TX: SendFrame<MaxDataFrame>,
{
    /// Accounts `amount` newly received bytes carried by a frame of `frame_type`.
    ///
    /// Returns `amount` on success, or a FLOW_CONTROL error if the total
    /// received exceeds the limit.
    pub fn on_new_rcvd(&self, frame_type: FrameType, amount: usize) -> Result<usize, Error> {
        let mut controller = self.0.lock().unwrap();
        controller.on_new_rcvd(frame_type, amount)
    }
}
impl<TX> ReceiveFrame<DataBlockedFrame> for ArcRecvController<TX> {
type Output = ();
fn recv_frame(&self, _frame: &DataBlockedFrame) -> Result<Self::Output, Error> {
Ok(())
}
}
/// Connection-level flow control with two halves.
///
/// The send-side controller is limited by the peer's MAX_DATA. The
/// receive-side controller extends the local limit as data arrives.
#[derive(Debug, Clone)]
pub struct FlowController<TX> {
    /// Send-side controller.
    pub sender: ArcSendControler<TX>,
    /// Receive-side controller.
    pub recver: ArcRecvController<TX>,
}

impl<TX: Clone> FlowController<TX> {
    /// Builds both controllers.
    ///
    /// The sender gets a clone of `broker`; the receiver takes `broker` itself.
    pub fn new(
        peer_initial_max_data: u64,
        local_initial_max_data: u64,
        broker: TX,
        tx_wakers: ArcSendWakers,
    ) -> Self {
        let sender = ArcSendControler::new(peer_initial_max_data, broker.clone(), tx_wakers);
        let recver = ArcRecvController::new(local_initial_max_data, broker);
        Self { sender, recver }
    }

    /// Raises the send window to `snd_wnd`.
    ///
    /// Values that do not exceed the current limit are ignored.
    pub fn reset_send_window(&self, snd_wnd: u64) {
        let Self { sender, .. } = self;
        sender.increase_limit(snd_wnd);
    }

    /// Reserves up to `quota` bytes of connection-level send credit.
    ///
    /// # Errors
    ///
    /// Fails with the stored error after [`on_conn_error`](Self::on_conn_error).
    pub fn send_limit(&self, quota: usize) -> Result<Credit<'_, TX>, Error>
    where
        TX: SendFrame<DataBlockedFrame>,
    {
        let Self { sender, .. } = self;
        sender.credit(quota)
    }

    /// Puts the send-side controller into the failed state with `error`.
    pub fn on_conn_error(&self, error: &Error) {
        let Self { sender, .. } = self;
        sender.on_error(error);
    }
}

impl<TX> FlowController<TX>
where
    TX: SendFrame<MaxDataFrame>,
{
    /// Accounts `amount` newly received bytes against the receive limit.
    pub fn on_new_rcvd(&self, frame_type: FrameType, amount: usize) -> Result<usize, Error> {
        let Self { recver, .. } = self;
        recver.on_new_rcvd(frame_type, amount)
    }
}
#[cfg(test)]
mod tests {
use derive_more::{Deref, DerefMut};
use super::*;
/// Test broker that records every DATA_BLOCKED frame it is asked to send.
#[derive(Clone, Debug, Default, Deref, DerefMut)]
struct SendControllerBroker(Arc<Mutex<Vec<DataBlockedFrame>>>);
impl SendFrame<DataBlockedFrame> for SendControllerBroker {
fn send_frame<I: IntoIterator<Item = DataBlockedFrame>>(&self, iter: I) {
self.0.lock().unwrap().extend(iter);
}
}
/// Exercises credit reservation, DATA_BLOCKED emission and limit increases
/// on the send-side controller.
#[test]
fn test_send_controler() {
let broker = SendControllerBroker::default();
// Start with a zero limit, then raise it to 100.
let controler = ArcSendControler::new(0, broker.clone(), Default::default());
controler.increase_limit(100);
// Asking for more than the limit grants only the remaining 100 bytes.
let mut credit = controler.credit(200).unwrap();
assert_eq!(credit.available(), 100);
credit.post_sent(50);
assert_eq!(credit.available(), 50);
credit.post_sent(50);
assert_eq!(credit.available(), 0);
drop(credit);
// Exhausting the limit queued exactly one DATA_BLOCKED frame at limit 100.
assert_eq!(broker.lock().unwrap().len(), 1);
assert_eq!(broker.lock().unwrap()[0].limit(), 100);
// Still blocked: the credit is empty. The later `len() == 2` check shows
// no second DATA_BLOCKED was queued for the same limit here.
let credit = controler.credit(1).unwrap();
assert_eq!(credit.available(), 0);
drop(credit);
// Raising the limit to 200 frees another 100 bytes.
controler.increase_limit(200);
let mut credit = controler.credit(200).unwrap();
assert_eq!(credit.available(), 100);
credit.post_sent(50);
assert_eq!(credit.available(), 50);
credit.post_sent(50);
assert_eq!(credit.available(), 0);
drop(credit);
// Exhausting the new limit queues a second DATA_BLOCKED frame at 200.
assert_eq!(broker.lock().unwrap().len(), 2);
assert_eq!(broker.lock().unwrap()[1].limit(), 200);
}
/// Test broker that records every MAX_DATA frame it is asked to send.
#[derive(Clone, Debug, Default, Deref, DerefMut)]
struct RecvControllerBroker(Arc<Mutex<Vec<MaxDataFrame>>>);
impl SendFrame<MaxDataFrame> for RecvControllerBroker {
fn send_frame<I: IntoIterator<Item = MaxDataFrame>>(&self, iter: I) {
self.0.lock().unwrap().extend(iter);
}
}
/// Exercises window growth and overflow detection on the receive-side controller.
#[test]
fn test_recv_controller() {
use crate::frame::{Fin, Flags, Len, Offset};
let broker = RecvControllerBroker::default();
// Initial limit 100, so the growth step is 50.
let controler = ArcRecvController::new(100, broker.clone());
// 20 bytes received: 20 + 50 < 100, so no MAX_DATA yet.
let amount = controler
.on_new_rcvd(
FrameType::Stream(Flags(Offset::Zero, Len::Omit, Fin::No)),
20,
)
.unwrap();
assert_eq!(amount, 20);
assert_eq!(broker.lock().unwrap().len(), 0);
// 30 more (50 total): 50 + 50 >= 100, so the limit grows to 150.
let amount = controler
.on_new_rcvd(
FrameType::Stream(Flags(Offset::Zero, Len::Sized, Fin::Yes)),
30,
)
.unwrap();
assert_eq!(amount, 30);
assert_eq!(broker.lock().unwrap().len(), 1);
assert_eq!(broker.lock().unwrap()[0].max_data(), 150);
// 101 more (151 total) exceeds the 150 limit, giving a FLOW_CONTROL error.
let result = controler.on_new_rcvd(FrameType::ResetStream, 101);
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), ErrorKind::FlowControl);
}
}