use std::sync::atomic::{AtomicU8, AtomicUsize, Ordering};
use qbase::net::tx::{ArcSendWaker, Signals};
pub const DEFAULT_ANTI_FACTOR: usize = 3;
#[derive(Debug)]
pub struct AntiAmplifier<const N: usize = DEFAULT_ANTI_FACTOR> {
credit: AtomicUsize,
tx_waker: ArcSendWaker,
state: AtomicU8,
}
impl<const N: usize> AntiAmplifier<N> {
const NORMAL: u8 = 0;
const GRANTED: u8 = 1;
const ABORTED: u8 = 2;
pub fn new(tx_waker: ArcSendWaker) -> Self {
Self {
credit: AtomicUsize::new(0),
tx_waker,
state: AtomicU8::new(0),
}
}
pub fn on_rcvd(&self, amount: usize) {
if self.state.load(Ordering::Acquire) != Self::NORMAL {
return;
}
self.credit.fetch_add(amount * N, Ordering::AcqRel);
self.tx_waker.wake_by(Signals::CREDIT);
}
pub fn balance(&self) -> Result<Option<usize>, Signals> {
match self.state.load(Ordering::Acquire) {
Self::GRANTED => Ok(Some(usize::MAX)),
Self::ABORTED => Ok(None),
Self::NORMAL => {
let credit = self.credit.load(Ordering::Acquire);
if credit == 0 {
let state = self.state.load(Ordering::Acquire);
if state == Self::NORMAL {
Err(Signals::CREDIT)
} else {
self.tx_waker.wake_by(Signals::CREDIT);
if state == Self::GRANTED {
Ok(Some(usize::MAX))
} else {
Ok(None)
}
}
} else {
Ok(Some(credit))
}
}
_ => unreachable!(),
}
}
pub fn on_sent(&self, amount: usize) {
if self.state.load(Ordering::Acquire) == Self::NORMAL {
self.credit.fetch_sub(amount, Ordering::AcqRel);
}
}
pub fn grant(&self) {
if self
.state
.compare_exchange(
Self::NORMAL,
Self::GRANTED,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
self.tx_waker.wake_by(Signals::CREDIT);
}
}
pub fn abort(&self) {
if self
.state
.compare_exchange(
Self::NORMAL,
Self::ABORTED,
Ordering::AcqRel,
Ordering::Acquire,
)
.is_ok()
{
self.tx_waker.wake_by(Signals::CREDIT);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_deposit_and_poll_apply() {
let waker = ArcSendWaker::new();
let anti_amplifier = AntiAmplifier::<3>::new(waker);
assert_eq!(anti_amplifier.balance(), Err(Signals::CREDIT));
anti_amplifier.on_rcvd(1);
assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 3);
assert_eq!(anti_amplifier.balance(), Ok(Some(3)));
assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 3);
anti_amplifier.on_sent(3);
assert_eq!(anti_amplifier.balance(), Err(Signals::CREDIT));
}
#[test]
fn test_multiple_deposits() {
let waker = ArcSendWaker::new();
let anti_amplifier = AntiAmplifier::<3>::new(waker);
anti_amplifier.on_rcvd(1);
assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 3);
anti_amplifier.on_rcvd(1);
assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 6);
assert_eq!(anti_amplifier.balance(), Ok(Some(6)));
assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 6);
anti_amplifier.on_sent(5);
assert_eq!(anti_amplifier.credit.load(Ordering::Acquire), 1);
}
}