#![doc = include_str!("../README.md")]
use std::num::TryFromIntError;
use std::ops::Index;
use std::time::Duration;
/// Read-only access to a bucket's raw balance in value units (one token is worth
/// `unit_cost` units); use `HTB::available` for a whole-token count.
///
/// Panics if `index` does not name a configured bucket.
impl<T> Index<T> for HTB<T>
where
    usize: From<T>,
{
    type Output = u64;

    fn index(&self, index: T) -> &u64 {
        let bucket = &self.state[usize::from(index)];
        &bucket.value
    }
}
/// Balance and capacity of one bucket, both in value units (one token is worth
/// `HTB::unit_cost` units).
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "borsh",
derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
struct Bucket {
/// Maximum balance: the configured capacity multiplied by `unit_cost`.
cap: u64,
/// Current balance; `HTB::new` starts it equal to `cap`.
value: u64,
}
/// Configuration of one bucket, as passed to `HTB::new`.
///
/// `HTB::new` expects the entries in depth-first order: entry `i` must have a
/// `this` that converts to `i`, only the first entry may be a root, and every
/// later entry's `parent` must be a bucket on the path from the root to the
/// previous entry.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "borsh",
derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
pub struct BucketCfg<T> {
/// Identifier of this bucket; `usize::from(this)` must equal its position in the slice.
pub this: T,
/// Bucket this one draws its tokens from, or `None` for the root.
pub parent: Option<T>,
/// Refill rate as `(tokens, period)`: `tokens` per `period`. A child is also
/// limited by what its parent currently holds.
/// With the `borsh` feature it is encoded by `borsh_rate_impl` as the `u64`
/// count, the period's whole seconds (`u64`) and its sub-second nanoseconds (`u32`).
#[cfg_attr(
feature = "borsh",
borsh(
serialize_with = "borsh_rate_impl::serialize",
deserialize_with = "borsh_rate_impl::deserialize"
)
)]
pub rate: (u64, Duration),
/// Maximum number of tokens the bucket can hold; buckets start full.
pub capacity: u64,
}
/// Borsh encoding for `BucketCfg::rate`: the count as `u64`, then the period as
/// whole seconds (`u64`) followed by sub-second nanoseconds (`u32`).
#[cfg(feature = "borsh")]
mod borsh_rate_impl {
    /// Largest valid nanosecond field is one less than this.
    const NANOS_PER_SEC: u32 = 1_000_000_000;

    /// Writes `(count, period)` in the layout described on the module.
    pub(crate) fn serialize(
        (r, duration): &(u64, core::time::Duration),
        writer: &mut impl borsh::io::Write,
    ) -> borsh::io::Result<()> {
        <u64 as borsh::BorshSerialize>::serialize(r, writer)?;
        <u64 as borsh::BorshSerialize>::serialize(&duration.as_secs(), writer)?;
        <u32 as borsh::BorshSerialize>::serialize(&duration.subsec_nanos(), writer)
    }

    /// Inverse of [`serialize`].
    ///
    /// # Errors
    ///
    /// Propagates reader errors, and returns `InvalidData` when the nanosecond field
    /// is one second or more. `serialize` never produces such a value, and passing
    /// it on to `Duration::new` would panic if carrying it into the seconds
    /// overflowed, so malformed input must not reach it.
    pub(crate) fn deserialize(
        reader: &mut impl borsh::io::Read,
    ) -> borsh::io::Result<(u64, core::time::Duration)> {
        let r = <u64 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        let secs = <u64 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        let nanos = <u32 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        if nanos >= NANOS_PER_SEC {
            return Err(borsh::io::Error::new(
                borsh::io::ErrorKind::InvalidData,
                "duration nanoseconds out of range",
            ));
        }
        Ok((r, core::time::Duration::new(secs, nanos)))
    }
}
/// Reasons `HTB::new` can reject a bucket configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Error {
    /// The configuration is empty, or its first entry is not a root (`parent: None`).
    NoRoot,
    /// A requested rate or capacity cannot be represented internally.
    InvalidRate,
    /// An entry is out of position, or its `parent` is not a valid enclosing bucket.
    InvalidStructure,
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Error::NoRoot => "Problem with a root node of some sort",
            Error::InvalidRate => "Requested message rate can't be represented",
            Error::InvalidStructure => "Problem with message structure",
        })
    }
}

/// A failed integer narrowing while scaling the configuration means a rate did not fit.
impl From<TryFromIntError> for Error {
    fn from(_: TryFromIntError) -> Self {
        Error::InvalidRate
    }
}
/// Hierarchical token bucket: a tree of buckets in which the root is refilled at
/// its own rate and every other bucket refills from its parent at its own rate,
/// each capped at its own capacity.
///
/// Balances are stored in value units: one token is worth `unit_cost` units.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "borsh",
derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
pub struct HTB<T> {
/// One bucket per configuration entry, indexed by `usize::from(this)`.
state: Vec<Bucket>,
/// Refill program built by `new` and replayed in order by `advance_ns`.
ops: Vec<Op<T>>,
/// Value units per token; capacities and rates were scaled by it in `new`.
// NOTE(review): this field is public and mutable, but changing it after `new`
// no longer matches the scaled capacities and rates — confirm whether it should
// be read-only.
pub unit_cost: u64,
/// Upper bound applied to the elapsed nanoseconds at the start of `advance_ns`.
time_limit: u64,
}
/// One step of the refill program replayed by `HTB::advance_ns`, which threads a
/// single running `flow` (value units in transit) through the ops in order.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
feature = "borsh",
derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
enum Op<T> {
/// Start of a step for the root: `flow = rate * elapsed`.
Inflow(u64),
/// Pool `flow` with the given bucket's balance, pass at most `rate * elapsed`
/// of it on as the new `flow`, and leave the rest in that bucket.
Take(T, u64),
/// Add `flow` to the given bucket up to its capacity; any excess becomes the new `flow`.
Deposit(T),
}
/// Least common multiple of `a` and `b`, with `lcm(0, x) == lcm(x, 0) == 0`.
///
/// Divides before multiplying, so no intermediate value exceeds the result. When
/// the true lcm does not fit in a `u128`, the result saturates at `u128::MAX`
/// instead of wrapping. `HTB::new` narrows the result to `u64`, so a saturated
/// value is rejected there rather than silently turning into a wrong time unit.
fn lcm(a: u128, b: u128) -> u128 {
    if a == 0 || b == 0 {
        // gcd(0, 0) is 0, so the general formula would divide by zero.
        return 0;
    }
    (a / gcd::Gcd::gcd(a, b)).saturating_mul(b)
}
impl<T> HTB<T>
where
T: Copy + Eq + PartialEq,
usize: From<T>,
{
pub fn new(tokens: &[BucketCfg<T>]) -> Result<Self, Error> {
if tokens.is_empty() || tokens[0].parent.is_some() {
return Err(Error::NoRoot);
}
let unit_cost: u64 = tokens
.iter()
.map(|cfg| cfg.rate.1.as_nanos())
.reduce(lcm)
.ok_or(Error::NoRoot)?
.try_into()?;
let rates = tokens
.iter()
.map(|cfg| {
u64::try_from(cfg.rate.0 as u128 * unit_cost as u128 / cfg.rate.1.as_nanos())
})
.collect::<Result<Vec<_>, _>>()?;
let things = tokens.iter().zip(rates.iter().copied()).enumerate();
let mut ops = Vec::new();
let mut items = Vec::new();
let mut stack = Vec::new();
for (ix, (cur, rate)) in things {
if ix != cur.this.into() {
return Err(Error::InvalidStructure);
}
if items.is_empty() && cur.parent.is_some() {
return Err(Error::NoRoot);
}
if cur.capacity as u128 * unit_cost as u128 > usize::MAX as u128 {
return Err(Error::InvalidRate);
}
items.push(Bucket {
cap: cur.capacity * unit_cost,
value: cur.capacity * unit_cost,
});
if cur.parent.as_ref() != stack.last() {
loop {
if let Some(parent) = stack.last() {
if Some(parent) == cur.parent.as_ref() {
ops.push(Op::Deposit(*parent));
break;
}
ops.push(Op::Deposit(*parent));
stack.pop();
} else {
return Err(Error::InvalidStructure);
}
}
}
stack.push(cur.this);
match cur.parent {
Some(parent) => ops.push(Op::Take(parent, rate)),
None => ops.push(Op::Inflow(rate)),
}
}
for leftover in stack.iter().rev().copied() {
ops.push(Op::Deposit(leftover));
}
let limit = unit_cost as u128 * rates.iter().map(|r| *r as u128).sum::<u128>();
if limit > usize::MAX as u128 / 2 {
return Err(Error::InvalidRate);
}
Ok(Self {
unit_cost,
state: items,
ops,
time_limit: limit as u64,
})
}
pub fn drain(&mut self) {
for bucket in self.state.iter_mut() {
bucket.value = 0;
}
}
pub fn refill(&mut self) {
for bucket in self.state.iter_mut() {
bucket.value = bucket.cap;
}
}
pub fn advance_ns(&mut self, time_diff: u64) {
let mut flow = 0u128;
let time_diff = std::cmp::min(time_diff, self.time_limit);
for op in self.ops.iter().copied() {
match op {
Op::Inflow(rate) => flow = rate as u128 * time_diff as u128,
Op::Take(k, rate) => {
let combined = flow + self.state[usize::from(k)].value as u128;
flow = combined.min(rate as u128 * time_diff as u128);
self.state[usize::from(k)].value = (combined - flow) as u64;
}
Op::Deposit(k) => {
let ix = usize::from(k);
let combined = flow + self.state[ix].value as u128;
let deposited = combined.min(self.state[ix].cap as u128);
self.state[ix].value = deposited as u64;
if combined > deposited {
flow = combined - deposited;
} else {
flow = 0;
}
}
}
}
}
pub fn advance(&mut self, time_diff: Duration) {
self.advance_ns(time_diff.as_nanos() as u64);
}
pub fn peek(&self, label: T) -> bool {
self.state[usize::from(label)].value >= self.unit_cost
}
pub fn available(&self, label: T) -> u64 {
self.state[usize::from(label)].value / self.unit_cost
}
pub fn peek_n(&self, label: T, cnt: usize) -> bool {
self.state[usize::from(label)].value >= self.unit_cost * cnt as u64
}
pub fn take(&mut self, label: T) -> bool {
let item = &mut self.state[usize::from(label)];
match item.value.checked_sub(self.unit_cost) {
Some(new) => {
item.value = new;
true
}
None => false,
}
}
pub fn take_n(&mut self, label: T, cnt: usize) -> bool {
let item = &mut self.state[usize::from(label)];
match item.value.checked_sub(self.unit_cost * cnt as u64) {
Some(new) => {
item.value = new;
true
}
None => false,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Rate {
Long,
Short,
Hedge,
HedgeFut,
Make,
}
impl From<Rate> for usize {
fn from(rate: Rate) -> Self {
rate as usize
}
}
fn sample_htb() -> HTB<Rate> {
HTB::new(&[
BucketCfg {
this: Rate::Long,
parent: None,
rate: (100, Duration::from_millis(200)),
capacity: 1500,
},
BucketCfg {
this: Rate::Short,
parent: Some(Rate::Long),
rate: (250, Duration::from_secs(1)),
capacity: 250,
},
BucketCfg {
this: Rate::Hedge,
parent: Some(Rate::Short),
rate: (1000, Duration::from_secs(1)),
capacity: 10,
},
BucketCfg {
this: Rate::HedgeFut,
parent: Some(Rate::Hedge),
rate: (2000, Duration::from_secs(2)),
capacity: 10,
},
BucketCfg {
this: Rate::Make,
parent: Some(Rate::Short),
rate: (1000, Duration::from_secs(1)),
capacity: 6,
},
])
.unwrap()
}
#[test]
fn it_works() {
let mut htb = sample_htb();
assert_eq!(htb.available(Rate::Hedge), 10);
assert!(htb.take_n(Rate::Hedge, 4));
assert_eq!(htb.available(Rate::Hedge), 6);
assert!(htb.take_n(Rate::Hedge, 4));
assert_eq!(htb.available(Rate::Hedge), 2);
assert!(htb.take_n(Rate::Hedge, 2));
assert_eq!(htb.available(Rate::Hedge), 0);
assert!(!htb.take_n(Rate::Hedge, 1));
htb.advance(Duration::from_millis(1));
assert!(htb.peek_n(Rate::Hedge, 1));
assert_eq!(htb.available(Rate::Hedge), 1);
assert!(!htb.peek_n(Rate::Hedge, 2));
assert!(htb.take(Rate::Hedge));
assert!(!htb.take(Rate::Hedge));
htb.advance(Duration::from_millis(5));
assert!(htb.peek_n(Rate::Hedge, 5));
assert!(!htb.peek_n(Rate::Hedge, 6));
htb.advance_ns(u64::MAX / 2);
assert!(htb.take_n(Rate::Hedge, 4));
htb.advance_ns(u64::MAX);
assert!(htb.take_n(Rate::Hedge, 4));
}
}