1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
#![doc = include_str!("../README.md")]
use std::num::TryFromIntError;
use std::ops::Index;
use std::time::Duration;

/// Read-only access to the raw content of a bucket.
///
/// The returned number is expressed in normalized units, not in tokens:
/// one token corresponds to `unit_cost` units.
///
/// # Panics
///
/// Panics if `index` does not convert to the position of an existing bucket.
impl<T> Index<T> for HTB<T>
where
    usize: From<T>,
{
    type Output = u64;

    fn index(&self, index: T) -> &Self::Output {
        let position = usize::from(index);
        let bucket = &self.state[position];
        &bucket.value
    }
}

/// Internal bucket state representation
///
/// Both fields are expressed in normalized units rather than tokens:
/// [`HTB::new`] initializes each of them to `capacity * unit_cost`.
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
struct Bucket {
    /// capacity, in normalized units; deposits made while advancing time
    /// are clamped to this amount
    cap: u64,
    /// currently contained value, in normalized units
    value: u64,
}

/// Bucket configuration
///
/// One entry of the slice passed to [`HTB::new`].
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BucketCfg<T> {
    /// Current bucket name
    ///
    /// Converted with `usize::from` it must be equal to the position of this
    /// entry in the slice passed to [`HTB::new`].
    pub this: T,
    /// Parent name, `None` for the root bucket
    pub parent: Option<T>,
    /// Allowed flow rate in number of tokens per duration
    ///
    /// The duration must not be zero.
    pub rate: (u64, Duration),
    /// Burst capacity in tokens, can be 0 if burst is not required
    /// at this step.
    ///
    /// If tokens are going to be consumed directly from this bucket
    /// this also limits how granular rate restriction can be.
    ///
    /// For example, for a rate of 10 tokens per second
    /// a capacity of 10 means you can consume 10 tokens every second
    /// at once or as 10 individual events distributed in any way through
    /// this second.
    /// A capacity of 5 gives the same rate of 10 tokens per second
    /// on average, but 5 tokens must be consumed in the first half of the second
    /// and the 5 remaining tokens - in the second half of the second.
    pub capacity: u64,
}

/// Reasons for [`HTB::new`] to reject a bucket configuration
#[derive(Clone, Copy, Debug)]
pub enum Error {
    /// First bucket passed to [`HTB::new`] must be a node with `parent` set to None
    NoRoot,
    /// Calculated flow rate is higher than what can fit into usize
    ///
    /// Flow rates are normalized using the least common multiple of all the
    /// rate periods. When the periods share few common factors the LCM gets
    /// close to their product, which can overflow. To fix this problem try to
    /// tweak the values so that they share more common factors. For example
    /// instead of using 881 and 883 (both are prime numbers) try using 882
    InvalidRate,

    /// Invalid config passed to HTB:
    ///
    /// Buckets should be given in depth first search traversal order:
    /// - root with `parent` set to None
    /// - higher priority child of the root
    /// - followed by high priority child of the child, if any, etc.
    /// - followed by the next child
    InvalidStructure,
}

impl std::error::Error for Error {}

impl std::fmt::Display for Error {
    /// Writes a short, human readable description of the problem
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::NoRoot => "Problem with a root node of some sort",
            Self::InvalidRate => "Requested message rate can't be represented",
            Self::InvalidStructure => "Problem with message structure",
        };
        f.write_str(message)
    }
}

/// Any integer that does not fit during rate normalization is reported
/// as [`Error::InvalidRate`]
impl From<TryFromIntError> for Error {
    fn from(_overflow: TryFromIntError) -> Self {
        Self::InvalidRate
    }
}

/// Hierarchical Token Bucket structure
///
/// You can advance time for HTB structure using [`advance`][Self::advance] and
/// [`advance_ns`][Self::advance_ns] and examine/alter internal state using
/// [`peek`][Self::peek]/[`take`][Self::take].
///
/// When several buckets are feeding from a single parent earlier one gets a priority
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct HTB<T> {
    /// Current content of every bucket, addressed by `usize::from(label)`
    state: Vec<Bucket>,
    /// Flattened sequence of operations replayed on every time update
    ops: Vec<Op<T>>,
    /// Normalized cost of a single token: bucket contents are kept in
    /// internal units and one token corresponds to this many units.
    ///
    /// Calculated as the least common multiple of all the rate periods
    /// expressed in nanoseconds.
    pub unit_cost: u64,
    /// Maximum time required to refill every possible cell, in nanoseconds;
    /// longer time updates are clamped to this value
    time_limit: u64,
}

/// Single step of the program replayed by [`HTB::advance_ns`]
///
/// Steps communicate through a running "flow" value measured in normalized units.
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum Op<T> {
    /// Start a new flow of `rate * elapsed time` units, `rate` is the only field
    Inflow(u64),
    /// Merge the flow with the content of the given (parent) bucket and pass on
    /// at most `rate * elapsed time` units, the rest stays in that bucket
    Take(T, u64),
    /// Add the flow to the given bucket up to its capacity,
    /// the excess keeps flowing to the next step
    Deposit(T),
}

/// Least common multiple of `a` and `b`
///
/// Returns 0 if either argument is 0. A result that does not fit into `u128`
/// saturates at `u128::MAX` instead of overflowing.
fn lcm(a: u128, b: u128) -> u128 {
    if a == 0 || b == 0 {
        // gcd(0, 0) is 0, bail out before dividing by it
        return 0;
    }
    // dividing first keeps the intermediate value as small as possible,
    // `a * b` can overflow even when the final result fits
    (a / gcd::Gcd::gcd(a, b)).saturating_mul(b)
}

impl<T> HTB<T>
where
    T: Copy + Eq + PartialEq,
    usize: From<T>,
{
    /// Create HTB for a given bucket configuration
    ///
    /// Buckets should be given in depth first search traversal order:
    /// - root with `parent` set to None
    /// - higher priority child of the root
    /// - followed by high priority child of the child, if any, etc.
    /// - followed by the next child
    ///
    /// # Errors
    /// If bucket configuration is invalid - returns an [`Error`] type describing a problem
    pub fn new(tokens: &[BucketCfg<T>]) -> Result<Self, Error> {
        if tokens.is_empty() || tokens[0].parent.is_some() {
            return Err(Error::NoRoot);
        }

        // first we need to convert flow rate from items per unit of time
        // to fractions per nanosecond
        let unit_cost: u64 = tokens
            .iter()
            .map(|cfg| cfg.rate.1.as_nanos())
            .reduce(lcm)
            .ok_or(Error::NoRoot)?
            .try_into()?;
        let rates = tokens
            .iter()
            .map(|cfg| {
                u64::try_from(cfg.rate.0 as u128 * unit_cost as u128 / cfg.rate.1.as_nanos())
            })
            .collect::<Result<Vec<_>, _>>()?;

        let things = tokens.iter().zip(rates.iter().copied()).enumerate();

        let mut ops = Vec::new();
        let mut items = Vec::new();
        let mut stack = Vec::new();

        for (ix, (cur, rate)) in things {
            // items must be given in form of depth first traversal
            if ix != cur.this.into() {
                return Err(Error::InvalidStructure);
            }

            // sanity check, first item must be root
            if items.is_empty() && cur.parent.is_some() {
                return Err(Error::NoRoot);
            }

            if cur.capacity as u128 * unit_cost as u128 > usize::MAX as u128 {
                return Err(Error::InvalidRate);
            }

            items.push(Bucket {
                cap: cur.capacity * unit_cost,
                value: cur.capacity * unit_cost,
            });

            if cur.parent.as_ref() != stack.last() {
                loop {
                    if let Some(parent) = stack.last() {
                        if Some(parent) == cur.parent.as_ref() {
                            ops.push(Op::Deposit(*parent));
                            break;
                        }
                        ops.push(Op::Deposit(*parent));
                        stack.pop();
                    } else {
                        return Err(Error::InvalidStructure);
                    }
                }
            }

            stack.push(cur.this);
            match cur.parent {
                Some(parent) => ops.push(Op::Take(parent, rate)),
                None => ops.push(Op::Inflow(rate)),
            }
        }
        for leftover in stack.iter().rev().copied() {
            ops.push(Op::Deposit(leftover));
        }

        let limit = unit_cost as u128 * rates.iter().map(|r| *r as u128).sum::<u128>();
        if limit > usize::MAX as u128 / 2 {
            // In this case is possible for "flow" to overflow the usize
            return Err(Error::InvalidRate);
        }

        Ok(Self {
            unit_cost,
            state: items,
            ops,
            time_limit: limit as u64,
        })
    }

    /// Advance time by number of nanoseconds
    ///
    /// Updates internal structure, see also [`advance`][Self::advance]
    ///
    /// # Performance
    ///
    /// Update cost is O(N) where N is number of buckets
    pub fn advance_ns(&mut self, time_diff: u64) {
        // we start at the top and insert new tokens according to this rules:
        // 1. at most `rate * time_diff` is propagated via links
        // 2. incoming `rate * time_diff` is combined with stored values
        // 3. unused tokens go back and deposited at the previous level
        // 4. at most incoming `rate * time_diff` is propagated back!
        // 5. at most `capacity` is deposited to nodes after the final pass
        let mut flow = 0u128;
        let time_diff = std::cmp::min(time_diff, self.time_limit);
        for op in self.ops.iter().copied() {
            match op {
                Op::Inflow(rate) => flow = rate as u128 * time_diff as u128,
                Op::Take(k, rate) => {
                    let combined = flow + self.state[usize::from(k)].value as u128;
                    flow = combined.min(rate as u128 * time_diff as u128);
                    self.state[usize::from(k)].value = (combined - flow) as u64;
                }
                Op::Deposit(k) => {
                    let ix = usize::from(k);
                    let combined = flow + self.state[ix].value as u128;
                    let deposited = combined.min(self.state[ix].cap as u128);
                    self.state[ix].value = deposited as u64;
                    if combined > deposited {
                        flow = combined - deposited;
                    } else {
                        flow = 0;
                    }
                }
            }
        }
    }

    /// Advance time by [`Duration`]
    ///
    /// Updates internal structure, see also [`advance_ns`][Self::advance_ns]
    pub fn advance(&mut self, time_diff: Duration) {
        self.advance_ns(time_diff.as_nanos() as u64);
    }

    /// Check if there's at least one token available at index `T`
    ///
    /// See also [`peek_n`][Self::peek_n]
    pub fn peek(&self, label: T) -> bool {
        self.state[usize::from(label)].value >= self.unit_cost
    }

    /// Check if there's at least `cnt` tokens available at index `T`
    ///
    /// See also [`peek`][Self::peek]
    pub fn peek_n(&self, label: T, cnt: usize) -> bool {
        self.state[usize::from(label)].value >= self.unit_cost * cnt as u64
    }

    /// Consume a single token from `T`
    ///
    /// See also [`take_n`][Self::take_n]
    pub fn take(&mut self, label: T) -> bool {
        let item = &mut self.state[usize::from(label)];
        match item.value.checked_sub(self.unit_cost) {
            Some(new) => {
                item.value = new;
                true
            }
            None => false,
        }
    }

    /// Consume `cnt` tokens from `T`
    ///
    /// See also [`take`][Self::take]
    pub fn take_n(&mut self, label: T, cnt: usize) -> bool {
        let item = &mut self.state[usize::from(label)];
        match item.value.checked_sub(self.unit_cost * cnt as u64) {
            Some(new) => {
                item.value = new;
                true
            }
            None => false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    /// Bucket names used by the sample hierarchy, the discriminant of each
    /// variant doubles as the bucket index
    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
    enum Rate {
        Long,
        Short,
        Hedge,
        HedgeFut,
        Make,
    }

    impl From<Rate> for usize {
        fn from(rate: Rate) -> Self {
            rate as usize
        }
    }

    /// Sample hierarchy given in depth first order:
    /// `Long` -> `Short` -> (`Hedge` -> `HedgeFut`, `Make`)
    fn sample_htb() -> HTB<Rate> {
        HTB::new(&[
            BucketCfg {
                this: Rate::Long,
                parent: None,
                rate: (100, Duration::from_millis(200)),
                capacity: 1500,
            },
            BucketCfg {
                this: Rate::Short,
                parent: Some(Rate::Long),
                rate: (250, Duration::from_secs(1)),
                capacity: 250,
            },
            BucketCfg {
                this: Rate::Hedge,
                parent: Some(Rate::Short),
                rate: (1000, Duration::from_secs(1)),
                capacity: 10,
            },
            BucketCfg {
                this: Rate::HedgeFut,
                parent: Some(Rate::Hedge),
                rate: (2000, Duration::from_secs(2)),
                capacity: 10,
            },
            BucketCfg {
                this: Rate::Make,
                parent: Some(Rate::Short),
                rate: (1000, Duration::from_secs(1)),
                capacity: 6,
            },
        ])
        .unwrap()
    }
    #[test]
    fn it_works() {
        let mut htb = sample_htb();
        // buckets start full: `Hedge` holds exactly its capacity of 10 tokens
        assert!(htb.take_n(Rate::Hedge, 4));
        assert!(htb.take_n(Rate::Hedge, 4));
        assert!(htb.take_n(Rate::Hedge, 2));
        assert!(!htb.take_n(Rate::Hedge, 1));
        // 1000 tokens per second means a single token per millisecond
        htb.advance(Duration::from_millis(1));
        assert!(htb.peek_n(Rate::Hedge, 1));
        assert!(!htb.peek_n(Rate::Hedge, 2));
        assert!(htb.take(Rate::Hedge));
        assert!(!htb.take(Rate::Hedge));
        // and five tokens in five milliseconds
        htb.advance(Duration::from_millis(5));
        assert!(htb.peek_n(Rate::Hedge, 5));
        assert!(!htb.peek_n(Rate::Hedge, 6));
        // huge time differences must not break the state
        htb.advance_ns(u64::MAX / 2);
        assert!(htb.take_n(Rate::Hedge, 4));
        htb.advance_ns(u64::MAX);
        assert!(htb.take_n(Rate::Hedge, 4));
    }
}