htb/
lib.rs

#![doc = include_str!("../README.md")]
use std::num::TryFromIntError;
use std::ops::Index;
use std::time::Duration;

/// Indexing gives read access to the raw internal value stored for a bucket
/// (measured in internal units rather than whole tokens)
impl<T> Index<T> for HTB<T>
where
    usize: From<T>,
{
    type Output = u64;

    fn index(&self, index: T) -> &Self::Output {
        &self.state[usize::from(index)].value
    }
}

/// Internal bucket state representation
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
struct Bucket {
    /// capacity, in abstract units
    cap: u64,
    /// currently contained value
    value: u64,
}

/// Bucket configuration
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
pub struct BucketCfg<T> {
    /// Current bucket name
    pub this: T,
    /// Parent name
    pub parent: Option<T>,
    /// Allowed flow rate in number of tokens per duration
    #[cfg_attr(
        feature = "borsh",
        borsh(
            serialize_with = "borsh_rate_impl::serialize",
            deserialize_with = "borsh_rate_impl::deserialize"
        )
    )]
    pub rate: (u64, Duration),
    /// Burst capacity in tokens, can be 0 if no burst is required
    /// at this step.
    ///
    /// If tokens are going to be consumed directly from this bucket
    /// this also limits how granular the rate restriction can be.
    ///
    /// For example, for a rate of 10 tokens per second a capacity of 10
    /// means you can consume 10 tokens every second, either all at once
    /// or as 10 individual events distributed in any way throughout that
    /// second. A capacity of 5 gives the same average rate of 10 tokens
    /// per second, but at most 5 tokens can be consumed in the first half
    /// of the second and the remaining 5 in the second half.
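    ///
    /// A minimal doctest sketch of the effect above (assuming this crate is
    /// available as `htb`; the plain `u8` bucket label is just for illustration):
    ///
    /// ```
    /// use htb::{BucketCfg, HTB};
    /// use std::time::Duration;
    ///
    /// // rate of 10 tokens per second with a burst capacity of 5
    /// let mut htb = HTB::new(&[BucketCfg {
    ///     this: 0u8,
    ///     parent: None,
    ///     rate: (10, Duration::from_secs(1)),
    ///     capacity: 5,
    /// }])
    /// .unwrap();
    ///
    /// // at most 5 tokens are available at once...
    /// assert!(htb.take_n(0, 5));
    /// assert!(!htb.peek(0));
    /// // ...and the other 5 only accumulate over the next half second
    /// htb.advance(Duration::from_millis(500));
    /// assert!(htb.take_n(0, 5));
    /// ```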
    pub capacity: u64,
}

#[cfg(feature = "borsh")]
/// Borsh serialisation for the `(u64, Duration)` rate field, without having
/// to wait for borsh upstream to add support for [`core::time::Duration`].
mod borsh_rate_impl {
    pub(crate) fn serialize(
        (r, duration): &(u64, core::time::Duration),
        writer: &mut impl borsh::io::Write,
    ) -> borsh::io::Result<()> {
        <u64 as borsh::BorshSerialize>::serialize(r, writer)?;
        <u64 as borsh::BorshSerialize>::serialize(&duration.as_secs(), writer)?;
        <u32 as borsh::BorshSerialize>::serialize(&duration.subsec_nanos(), writer)
    }

    pub(crate) fn deserialize(
        reader: &mut impl borsh::io::Read,
    ) -> borsh::io::Result<(u64, core::time::Duration)> {
        let r = <u64 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        let secs = <u64 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        let nanos = <u32 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
        Ok((r, core::time::Duration::new(secs, nanos)))
    }
}

#[derive(Clone, Copy, Debug)]
pub enum Error {
    /// First bucket passed to [`HTB::new`] must be a node with `parent` set to None
    NoRoot,
    /// Calculated flow rate is higher than what can fit into `usize`
    ///
    /// The flow rate is calculated using the least common multiple of the durations;
    /// when they share no common factor the LCM ends up being their product, which
    /// can overflow. To fix this problem try to tweak the values so they share a
    /// larger common factor (giving a smaller LCM). For example, instead of using
    /// 881 and 883 (both are prime numbers) try using 882 for both.
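    ///
    /// A hypothetical configuration that triggers this error: the two durations
    /// below are co-prime in nanoseconds, so their LCM is their product, which
    /// does not fit into 64 bits (assuming this crate is available as `htb`).
    ///
    /// ```
    /// use htb::{BucketCfg, Error, HTB};
    /// use std::time::Duration;
    ///
    /// let res = HTB::new(&[
    ///     BucketCfg {
    ///         this: 0u8,
    ///         parent: None,
    ///         rate: (1, Duration::from_nanos(1u64 << 33)),
    ///         capacity: 1,
    ///     },
    ///     BucketCfg {
    ///         this: 1,
    ///         parent: Some(0),
    ///         rate: (1, Duration::from_nanos(3u64.pow(21))),
    ///         capacity: 1,
    ///     },
    /// ]);
    /// assert!(matches!(res, Err(Error::InvalidRate)));
    /// ```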
    InvalidRate,

    /// Invalid config passed to HTB:
    ///
    /// Buckets should be given in depth-first search traversal order:
    /// - root with `parent` set to None
    /// - highest priority child of the root
    /// - followed by the highest priority child of that child, if any, etc.
    /// - followed by the next child
    InvalidStructure,
}
impl std::error::Error for Error {}
impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::NoRoot => f.write_str("First bucket must be a root node with no parent"),
            Error::InvalidRate => f.write_str("Requested flow rate can't be represented"),
            Error::InvalidStructure => f.write_str("Buckets are not in depth-first traversal order"),
        }
    }
}

impl From<TryFromIntError> for Error {
    fn from(_: TryFromIntError) -> Self {
        Error::InvalidRate
    }
}

/// Hierarchical Token Bucket structure
///
/// You can advance time for the HTB structure using [`advance`][Self::advance] and
/// [`advance_ns`][Self::advance_ns] and examine/alter its internal state using
/// [`peek`][Self::peek]/[`take`][Self::take].
///
/// When several buckets feed from a single parent the one listed earlier gets priority.
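///
/// A minimal usage sketch with a single root bucket (assuming this crate is
/// available as `htb`; a plain `u8` is used as the bucket label):
///
/// ```
/// use htb::{BucketCfg, HTB};
/// use std::time::Duration;
///
/// // 100 tokens per second with a burst capacity of 10
/// let mut htb = HTB::new(&[BucketCfg {
///     this: 0u8,
///     parent: None,
///     rate: (100, Duration::from_secs(1)),
///     capacity: 10,
/// }])
/// .unwrap();
///
/// // the bucket starts full
/// assert!(htb.take_n(0, 10));
/// assert!(!htb.peek(0));
///
/// // 10ms at 100 tokens per second buys exactly one more token
/// htb.advance(Duration::from_millis(10));
/// assert!(htb.take(0));
/// assert!(!htb.take(0));
/// ```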
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
pub struct HTB<T> {
    state: Vec<Bucket>,
    ops: Vec<Op<T>>,
    /// Normalized unit cost: each token corresponds to this many internal units
    pub unit_cost: u64,
    /// Maximum time required to refill every possible cell
    time_limit: u64,
}

#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(
    feature = "borsh",
    derive(borsh::BorshSerialize, borsh::BorshDeserialize)
)]
enum Op<T> {
    Inflow(u64),
    Take(T, u64),
    Deposit(T),
}

fn lcm(a: u128, b: u128) -> u128 {
    (a * b) / gcd::Gcd::gcd(a, b)
}

impl<T> HTB<T>
where
    T: Copy + Eq + PartialEq,
    usize: From<T>,
{
    /// Create HTB for a given bucket configuration
    ///
    /// Buckets should be given in depth-first search traversal order:
    /// - root with `parent` set to None
    /// - highest priority child of the root
    /// - followed by the highest priority child of that child, if any, etc.
    /// - followed by the next child
    ///
    /// # Errors
    /// If the bucket configuration is invalid, returns an [`Error`] describing the problem
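    ///
    /// # Example
    ///
    /// A sketch of a valid ordering, root first and then its children in
    /// priority order (plain `u8` labels are used purely for illustration):
    ///
    /// ```
    /// use htb::{BucketCfg, HTB};
    /// use std::time::Duration;
    ///
    /// let htb = HTB::new(&[
    ///     // root: 100 tokens per second overall
    ///     BucketCfg {
    ///         this: 0u8,
    ///         parent: None,
    ///         rate: (100, Duration::from_secs(1)),
    ///         capacity: 20,
    ///     },
    ///     // higher priority child, listed first
    ///     BucketCfg {
    ///         this: 1,
    ///         parent: Some(0),
    ///         rate: (60, Duration::from_secs(1)),
    ///         capacity: 10,
    ///     },
    ///     // lower priority child
    ///     BucketCfg {
    ///         this: 2,
    ///         parent: Some(0),
    ///         rate: (40, Duration::from_secs(1)),
    ///         capacity: 10,
    ///     },
    /// ])
    /// .unwrap();
    /// assert!(htb.peek(1) && htb.peek(2));
    /// ```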
    pub fn new(tokens: &[BucketCfg<T>]) -> Result<Self, Error> {
        if tokens.is_empty() || tokens[0].parent.is_some() {
            return Err(Error::NoRoot);
        }

        // first we need to convert flow rates from tokens per unit of time
        // to internal units per nanosecond
        let unit_cost: u64 = tokens
            .iter()
            .map(|cfg| cfg.rate.1.as_nanos())
            .reduce(lcm)
            .ok_or(Error::NoRoot)?
            .try_into()?;
        let rates = tokens
            .iter()
            .map(|cfg| {
                u64::try_from(cfg.rate.0 as u128 * unit_cost as u128 / cfg.rate.1.as_nanos())
            })
            .collect::<Result<Vec<_>, _>>()?;

        let things = tokens.iter().zip(rates.iter().copied()).enumerate();

        let mut ops = Vec::new();
        let mut items = Vec::new();
        let mut stack = Vec::new();

        for (ix, (cur, rate)) in things {
            // items must be given in depth-first traversal order
            if ix != cur.this.into() {
                return Err(Error::InvalidStructure);
            }

            // sanity check, first item must be root
            if items.is_empty() && cur.parent.is_some() {
                return Err(Error::NoRoot);
            }

            if cur.capacity as u128 * unit_cost as u128 > usize::MAX as u128 {
                return Err(Error::InvalidRate);
            }

            items.push(Bucket {
                cap: cur.capacity * unit_cost,
                value: cur.capacity * unit_cost,
            });

            // when the parent is not on top of the stack the traversal moved
            // back up the tree: emit Deposit ops for every bucket we are
            // leaving (and the shared parent) so unused flow can be stored on
            // the way back out
            if cur.parent.as_ref() != stack.last() {
                loop {
                    if let Some(parent) = stack.last() {
                        if Some(parent) == cur.parent.as_ref() {
                            ops.push(Op::Deposit(*parent));
                            break;
                        }
                        ops.push(Op::Deposit(*parent));
                        stack.pop();
                    } else {
                        return Err(Error::InvalidStructure);
                    }
                }
            }

            stack.push(cur.this);
            match cur.parent {
                Some(parent) => ops.push(Op::Take(parent, rate)),
                None => ops.push(Op::Inflow(rate)),
            }
        }
        for leftover in stack.iter().rev().copied() {
            ops.push(Op::Deposit(leftover));
        }

        let limit = unit_cost as u128 * rates.iter().map(|r| *r as u128).sum::<u128>();
        if limit > usize::MAX as u128 / 2 {
            // In this case it is possible for `flow` to overflow usize
            return Err(Error::InvalidRate);
        }

        Ok(Self {
            unit_cost,
            state: items,
            ops,
            time_limit: limit as u64,
        })
    }

    /// Drain all the available tokens
    pub fn drain(&mut self) {
        for bucket in self.state.iter_mut() {
            bucket.value = 0;
        }
    }

    /// Refill all the available buckets to full capacity
    pub fn refill(&mut self) {
        for bucket in self.state.iter_mut() {
            bucket.value = bucket.cap;
        }
    }

    /// Advance time by a number of nanoseconds
    ///
    /// Updates the internal structure, see also [`advance`][Self::advance]
    ///
    /// # Performance
    ///
    /// Update cost is O(N) where N is the number of buckets
    pub fn advance_ns(&mut self, time_diff: u64) {
        // we start at the top and insert new tokens according to these rules:
        // 1. at most `rate * time_diff` is propagated via links
        // 2. incoming `rate * time_diff` is combined with stored values
        // 3. unused tokens go back and are deposited at the previous level
        // 4. at most the incoming `rate * time_diff` is propagated back!
        // 5. at most `capacity` is deposited to nodes after the final pass
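        //
        // A small worked sketch with hypothetical numbers for a single
        // root -> child chain, i.e. ops = [Inflow, Take(root), Deposit(child),
        // Deposit(root)]: with a root rate of 2 units/ns, a child rate of
        // 1 unit/ns and time_diff = 3, Inflow puts 6 units in flight, Take
        // lets at most 3 of them continue towards the child (rule 1) and
        // stores the rest in the root, Deposit(child) keeps up to the child's
        // capacity (rule 5) and Deposit(root) returns any remaining excess to
        // the root on the way out.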
        let mut flow = 0u128;
        let time_diff = std::cmp::min(time_diff, self.time_limit);
        for op in self.ops.iter().copied() {
            match op {
                Op::Inflow(rate) => flow = rate as u128 * time_diff as u128,
                Op::Take(k, rate) => {
                    let combined = flow + self.state[usize::from(k)].value as u128;
                    flow = combined.min(rate as u128 * time_diff as u128);
                    self.state[usize::from(k)].value = (combined - flow) as u64;
                }
                Op::Deposit(k) => {
                    let ix = usize::from(k);
                    let combined = flow + self.state[ix].value as u128;
                    let deposited = combined.min(self.state[ix].cap as u128);
                    self.state[ix].value = deposited as u64;
                    if combined > deposited {
                        flow = combined - deposited;
                    } else {
                        flow = 0;
                    }
                }
            }
        }
    }

    /// Advance time by [`Duration`]
    ///
    /// Updates the internal structure, see also [`advance_ns`][Self::advance_ns]
    pub fn advance(&mut self, time_diff: Duration) {
        self.advance_ns(time_diff.as_nanos() as u64);
    }

    /// Check if there's at least one token available at index `T`
    ///
    /// See also [`peek_n`][Self::peek_n]
    pub fn peek(&self, label: T) -> bool {
        self.state[usize::from(label)].value >= self.unit_cost
    }

    /// Shows how many tokens are available at index `T`
    pub fn available(&self, label: T) -> u64 {
        self.state[usize::from(label)].value / self.unit_cost
    }

    /// Check if there are at least `cnt` tokens available at index `T`
    ///
    /// See also [`peek`][Self::peek]
    pub fn peek_n(&self, label: T, cnt: usize) -> bool {
        self.state[usize::from(label)].value >= self.unit_cost * cnt as u64
    }

    /// Consume a single token from `T`
    ///
    /// See also [`take_n`][Self::take_n]
    pub fn take(&mut self, label: T) -> bool {
        let item = &mut self.state[usize::from(label)];
        match item.value.checked_sub(self.unit_cost) {
            Some(new) => {
                item.value = new;
                true
            }
            None => false,
        }
    }

    /// Consume `cnt` tokens from `T`
    ///
    /// See also [`take`][Self::take]
    pub fn take_n(&mut self, label: T, cnt: usize) -> bool {
        let item = &mut self.state[usize::from(label)];
        match item.value.checked_sub(self.unit_cost * cnt as u64) {
            Some(new) => {
                item.value = new;
                true
            }
            None => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    enum Rate {
        Long,
        Short,
        Hedge,
        HedgeFut,
        Make,
    }

    impl From<Rate> for usize {
        fn from(rate: Rate) -> Self {
            rate as usize
        }
    }

    fn sample_htb() -> HTB<Rate> {
        HTB::new(&[
            BucketCfg {
                this: Rate::Long,
                parent: None,
                rate: (100, Duration::from_millis(200)),
                capacity: 1500,
            },
            BucketCfg {
                this: Rate::Short,
                parent: Some(Rate::Long),
                rate: (250, Duration::from_secs(1)),
                capacity: 250,
            },
            BucketCfg {
                this: Rate::Hedge,
                parent: Some(Rate::Short),
                rate: (1000, Duration::from_secs(1)),
                capacity: 10,
            },
            BucketCfg {
                this: Rate::HedgeFut,
                parent: Some(Rate::Hedge),
                rate: (2000, Duration::from_secs(2)),
                capacity: 10,
            },
            BucketCfg {
                this: Rate::Make,
                parent: Some(Rate::Short),
                rate: (1000, Duration::from_secs(1)),
                capacity: 6,
            },
        ])
        .unwrap()
    }
    #[test]
    fn it_works() {
        let mut htb = sample_htb();
        assert_eq!(htb.available(Rate::Hedge), 10);
        assert!(htb.take_n(Rate::Hedge, 4));
        assert_eq!(htb.available(Rate::Hedge), 6);
        assert!(htb.take_n(Rate::Hedge, 4));
        assert_eq!(htb.available(Rate::Hedge), 2);
        assert!(htb.take_n(Rate::Hedge, 2));
        assert_eq!(htb.available(Rate::Hedge), 0);
        assert!(!htb.take_n(Rate::Hedge, 1));
        htb.advance(Duration::from_millis(1));
        assert!(htb.peek_n(Rate::Hedge, 1));
        assert_eq!(htb.available(Rate::Hedge), 1);
        assert!(!htb.peek_n(Rate::Hedge, 2));
        assert!(htb.take(Rate::Hedge));
        assert!(!htb.take(Rate::Hedge));
        htb.advance(Duration::from_millis(5));
        assert!(htb.peek_n(Rate::Hedge, 5));
        assert!(!htb.peek_n(Rate::Hedge, 6));
        htb.advance_ns(u64::MAX / 2);
        assert!(htb.take_n(Rate::Hedge, 4));
        htb.advance_ns(u64::MAX);
        assert!(htb.take_n(Rate::Hedge, 4));
    }
}