1#![doc = include_str!("../README.md")]
2use std::num::TryFromIntError;
3use std::ops::Index;
4use std::time::Duration;
5
6impl<T> Index<T> for HTB<T>
7where
8 usize: From<T>,
9{
10 type Output = u64;
11
12 fn index(&self, index: T) -> &Self::Output {
13 &self.state[usize::from(index)].value
14 }
15}
16
/// Runtime state of a single bucket.
///
/// Both fields are measured in internal units, where one token equals
/// `HTB::unit_cost` units. Field order is part of the serde/borsh encoding.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
struct Bucket {
    /// Maximum value the bucket can hold (`capacity * unit_cost`, set by `HTB::new`).
    cap: u64,
    /// Current fill level; `value / unit_cost` whole tokens are available.
    value: u64,
}
30
/// Configuration of one bucket, as passed (in a slice) to `HTB::new`.
///
/// Field order is part of the serde/borsh encoding.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
pub struct BucketCfg<T> {
    /// Label of this bucket. `HTB::new` requires `usize::from(this)` to equal
    /// the entry's position in the config slice.
    pub this: T,
    /// Label of the parent bucket, or `None` for the root (which must be the
    /// first entry).
    pub parent: Option<T>,
    /// Refill rate: `rate.0` tokens every `rate.1`.
    ///
    /// With the `borsh` feature this is encoded as the `u64` token count, the
    /// `u64` whole seconds and the `u32` sub-second nanoseconds.
    #[cfg_attr(
        feature = "borsh",
        borsh(
            serialize_with = "borsh_rate_impl::serialize",
            deserialize_with = "borsh_rate_impl::deserialize"
        )
    )]
    pub rate: (u64, Duration),
    /// Maximum number of tokens the bucket can hold; `HTB::new` starts every
    /// bucket full.
    pub capacity: u64,
}
67
#[cfg(feature = "borsh")]
mod borsh_rate_impl {
    //! Borsh encoding for `BucketCfg::rate`: the `u64` token count, then the
    //! period as `u64` whole seconds followed by `u32` sub-second nanoseconds.

    /// Nanoseconds per second; a canonical sub-second part is strictly below this.
    const NANOS_PER_SEC: u32 = 1_000_000_000;

    /// Writes `(count, period)` as `count: u64`, `secs: u64`, `subsec_nanos: u32`.
    pub(crate) fn serialize(
        (r, duration): &(u64, core::time::Duration),
        writer: &mut impl borsh::io::Write,
    ) -> borsh::io::Result<()> {
        <u64 as borsh::BorshSerialize>::serialize(r, writer)?;
        <u64 as borsh::BorshSerialize>::serialize(&duration.as_secs(), writer)?;
        <u32 as borsh::BorshSerialize>::serialize(&duration.subsec_nanos(), writer)
    }

    /// Reads the layout produced by [`serialize`].
    ///
    /// # Errors
    ///
    /// Propagates reader errors. Returns `InvalidData` when the nanosecond part
    /// is `>= 1_000_000_000`. Such input is never produced by `serialize`;
    /// passing it to `Duration::new` would either silently re-normalize it
    /// (a non-canonical encoding) or panic when the carry overflows the seconds.
    pub(crate) fn deserialize(
        reader: &mut impl borsh::io::Read,
    ) -> borsh::io::Result<(u64, core::time::Duration)> {
        let r = <u64 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        let secs = <u64 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        let nanos = <u32 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        if nanos >= NANOS_PER_SEC {
            return Err(borsh::io::Error::new(
                borsh::io::ErrorKind::InvalidData,
                "rate duration sub-second nanoseconds out of range",
            ));
        }
        Ok((r, core::time::Duration::new(secs, nanos)))
    }
}
90
/// Reasons `HTB::new` can reject a bucket configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Error {
    /// No usable root: the configuration list is empty, or its first entry
    /// (the root position) declares a parent.
    NoRoot,
    /// A rate, capacity or derived quantity cannot be represented in the
    /// limiter's internal integer arithmetic (an overflow or failed narrowing).
    InvalidRate,
    /// The buckets are not laid out as expected: a label does not match its
    /// position in the list, or a parent is not an ancestor of the preceding
    /// entry (this includes a second root).
    InvalidStructure,
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Error::NoRoot => "Problem with a root node of some sort",
            Error::InvalidRate => "Requested message rate can't be represented",
            Error::InvalidStructure => "Problem with message structure",
        })
    }
}

/// Failed integer narrowing during configuration is reported as
/// [`Error::InvalidRate`].
impl From<TryFromIntError> for Error {
    fn from(_: TryFromIntError) -> Self {
        Error::InvalidRate
    }
}
128
/// Hierarchical token bucket rate limiter.
///
/// Buckets are addressed by labels of type `T` that convert to `usize`
/// indices. Tokens are stored internally as integer units, with one token
/// worth `unit_cost` units. Field order is part of the serde/borsh encoding.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
pub struct HTB<T> {
    /// Per-bucket state, indexed by `usize::from(label)`.
    state: Vec<Bucket>,
    /// Refill program built by `new` and replayed by `advance_ns`.
    ops: Vec<Op<T>>,
    /// Number of internal units that make up one token.
    pub unit_cost: u64,
    /// Upper bound (in nanoseconds) that `advance_ns` clamps each time step to.
    time_limit: u64,
}
150
/// One step of the refill program that `HTB::new` builds and
/// `HTB::advance_ns` executes in order, threading a single running `flow`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
enum Op<T> {
    /// Start the flow with the root's inflow: `rate` (units/ns) × elapsed ns.
    Inflow(u64),
    /// A child draws from parent `T`: at most `rate` × elapsed units are taken
    /// from the current flow plus the parent's stored value, and the rest is
    /// stored back in the parent.
    Take(T, u64),
    /// Add the current flow to bucket `T` up to its capacity; any excess
    /// stays in the flow for the following ops.
    Deposit(T),
}
162
163fn lcm(a: u128, b: u128) -> u128 {
164 (a * b) / gcd::Gcd::gcd(a, b)
165}
166
167impl<T> HTB<T>
168where
169 T: Copy + Eq + PartialEq,
170 usize: From<T>,
171{
172 pub fn new(tokens: &[BucketCfg<T>]) -> Result<Self, Error> {
183 if tokens.is_empty() || tokens[0].parent.is_some() {
184 return Err(Error::NoRoot);
185 }
186
187 let unit_cost: u64 = tokens
190 .iter()
191 .map(|cfg| cfg.rate.1.as_nanos())
192 .reduce(lcm)
193 .ok_or(Error::NoRoot)?
194 .try_into()?;
195 let rates = tokens
196 .iter()
197 .map(|cfg| {
198 u64::try_from(cfg.rate.0 as u128 * unit_cost as u128 / cfg.rate.1.as_nanos())
199 })
200 .collect::<Result<Vec<_>, _>>()?;
201
202 let things = tokens.iter().zip(rates.iter().copied()).enumerate();
203
204 let mut ops = Vec::new();
205 let mut items = Vec::new();
206 let mut stack = Vec::new();
207
208 for (ix, (cur, rate)) in things {
209 if ix != cur.this.into() {
211 return Err(Error::InvalidStructure);
212 }
213
214 if items.is_empty() && cur.parent.is_some() {
216 return Err(Error::NoRoot);
217 }
218
219 if cur.capacity as u128 * unit_cost as u128 > usize::MAX as u128 {
220 return Err(Error::InvalidRate);
221 }
222
223 items.push(Bucket {
224 cap: cur.capacity * unit_cost,
225 value: cur.capacity * unit_cost,
226 });
227
228 if cur.parent.as_ref() != stack.last() {
229 loop {
230 if let Some(parent) = stack.last() {
231 if Some(parent) == cur.parent.as_ref() {
232 ops.push(Op::Deposit(*parent));
233 break;
234 }
235 ops.push(Op::Deposit(*parent));
236 stack.pop();
237 } else {
238 return Err(Error::InvalidStructure);
239 }
240 }
241 }
242
243 stack.push(cur.this);
244 match cur.parent {
245 Some(parent) => ops.push(Op::Take(parent, rate)),
246 None => ops.push(Op::Inflow(rate)),
247 }
248 }
249 for leftover in stack.iter().rev().copied() {
250 ops.push(Op::Deposit(leftover));
251 }
252
253 let limit = unit_cost as u128 * rates.iter().map(|r| *r as u128).sum::<u128>();
254 if limit > usize::MAX as u128 / 2 {
255 return Err(Error::InvalidRate);
257 }
258
259 Ok(Self {
260 unit_cost,
261 state: items,
262 ops,
263 time_limit: limit as u64,
264 })
265 }
266
267 pub fn drain(&mut self) {
269 for bucket in self.state.iter_mut() {
270 bucket.value = 0;
271 }
272 }
273
274 pub fn refill(&mut self) {
276 for bucket in self.state.iter_mut() {
277 bucket.value = bucket.cap;
278 }
279 }
280
281 pub fn advance_ns(&mut self, time_diff: u64) {
289 let mut flow = 0u128;
296 let time_diff = std::cmp::min(time_diff, self.time_limit);
297 for op in self.ops.iter().copied() {
298 match op {
299 Op::Inflow(rate) => flow = rate as u128 * time_diff as u128,
300 Op::Take(k, rate) => {
301 let combined = flow + self.state[usize::from(k)].value as u128;
302 flow = combined.min(rate as u128 * time_diff as u128);
303 self.state[usize::from(k)].value = (combined - flow) as u64;
304 }
305 Op::Deposit(k) => {
306 let ix = usize::from(k);
307 let combined = flow + self.state[ix].value as u128;
308 let deposited = combined.min(self.state[ix].cap as u128);
309 self.state[ix].value = deposited as u64;
310 if combined > deposited {
311 flow = combined - deposited;
312 } else {
313 flow = 0;
314 }
315 }
316 }
317 }
318 }
319
320 pub fn advance(&mut self, time_diff: Duration) {
324 self.advance_ns(time_diff.as_nanos() as u64);
325 }
326
327 pub fn peek(&self, label: T) -> bool {
331 self.state[usize::from(label)].value >= self.unit_cost
332 }
333
334 pub fn available(&self, label: T) -> u64 {
336 self.state[usize::from(label)].value / self.unit_cost
337 }
338
339 pub fn peek_n(&self, label: T, cnt: usize) -> bool {
343 self.state[usize::from(label)].value >= self.unit_cost * cnt as u64
344 }
345
346 pub fn take(&mut self, label: T) -> bool {
350 let item = &mut self.state[usize::from(label)];
351 match item.value.checked_sub(self.unit_cost) {
352 Some(new) => {
353 item.value = new;
354 true
355 }
356 None => false,
357 }
358 }
359
360 pub fn take_n(&mut self, label: T, cnt: usize) -> bool {
364 let item = &mut self.state[usize::from(label)];
365 match item.value.checked_sub(self.unit_cost * cnt as u64) {
366 Some(new) => {
367 item.value = new;
368 true
369 }
370 None => false,
371 }
372 }
373}
374
#[cfg(test)]
mod tests {
    use super::*;
    /// Bucket labels for the sample hierarchy. The implicit discriminants
    /// 0..=4 double as the bucket indices that `HTB::new` checks.
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    enum Rate {
        Long,
        Short,
        Hedge,
        HedgeFut,
        Make,
    }

    impl From<Rate> for usize {
        fn from(rate: Rate) -> Self {
            rate as usize
        }
    }

    // Sample hierarchy, listed in depth-first order:
    //
    // Long (root)
    // └─ Short
    //    ├─ Hedge
    //    │  └─ HedgeFut
    //    └─ Make
    fn sample_htb() -> HTB<Rate> {
        HTB::new(&[
            BucketCfg {
                this: Rate::Long,
                parent: None,
                rate: (100, Duration::from_millis(200)),
                capacity: 1500,
            },
            BucketCfg {
                this: Rate::Short,
                parent: Some(Rate::Long),
                rate: (250, Duration::from_secs(1)),
                capacity: 250,
            },
            BucketCfg {
                this: Rate::Hedge,
                parent: Some(Rate::Short),
                rate: (1000, Duration::from_secs(1)),
                capacity: 10,
            },
            BucketCfg {
                this: Rate::HedgeFut,
                parent: Some(Rate::Hedge),
                rate: (2000, Duration::from_secs(2)),
                capacity: 10,
            },
            BucketCfg {
                this: Rate::Make,
                parent: Some(Rate::Short),
                rate: (1000, Duration::from_secs(1)),
                capacity: 6,
            },
        ])
        .unwrap()
    }
    // Drains Hedge, then checks refill amounts. Each assertion depends on the
    // state left by the previous ones.
    #[test]
    fn it_works() {
        let mut htb = sample_htb();
        // Buckets start full: Hedge has its capacity of 10 tokens.
        assert_eq!(htb.available(Rate::Hedge), 10);
        assert!(htb.take_n(Rate::Hedge, 4));
        assert_eq!(htb.available(Rate::Hedge), 6);
        assert!(htb.take_n(Rate::Hedge, 4));
        assert_eq!(htb.available(Rate::Hedge), 2);
        assert!(htb.take_n(Rate::Hedge, 2));
        assert_eq!(htb.available(Rate::Hedge), 0);
        assert!(!htb.take_n(Rate::Hedge, 1));
        // 1 ms at Hedge's 1000 tokens/s refills exactly one token.
        htb.advance(Duration::from_millis(1));
        assert!(htb.peek_n(Rate::Hedge, 1));
        assert_eq!(htb.available(Rate::Hedge), 1);
        assert!(!htb.peek_n(Rate::Hedge, 2));
        assert!(htb.take(Rate::Hedge));
        assert!(!htb.take(Rate::Hedge));
        // 5 ms refills exactly five tokens.
        htb.advance(Duration::from_millis(5));
        assert!(htb.peek_n(Rate::Hedge, 5));
        assert!(!htb.peek_n(Rate::Hedge, 6));
        // Very large time steps must not overflow.
        htb.advance_ns(u64::MAX / 2);
        assert!(htb.take_n(Rate::Hedge, 4));
        htb.advance_ns(u64::MAX);
        assert!(htb.take_n(Rate::Hedge, 4));
    }
}
453}