midas_rs/
lib.rs

//! Rust implementation of [MIDAS](https://github.com/bhatiasiddharth/MIDAS),
//! which assigns an anomaly score to each `(source, dest, time)` edge of a
//! stream.
//!
//! ```rust
//! use midas_rs::{Int, Float, MidasR};
//!
//! fn main() {
//!     // For configuration options, refer to MidasRParams
//!     let mut midas = MidasR::new(Default::default());
//!
//!     println!("{:.6}", midas.insert((1, 1, 1)));
//!     println!("{:.6}", midas.insert((1, 2, 1)));
//!     println!("{:.6}", midas.insert((1, 1, 2)));
//!     println!("{:.6}", midas.insert((1, 2, 3)));
//!
//!     assert_eq!(midas.insert((1, 2, 4)), midas.query(1, 2));
//! }
//! ```
19
20use rand::rngs::SmallRng;
21
22pub mod default {
23    use super::{Float, Int};
24
25    pub const NUM_ROWS: Int = 2;
26    pub const NUM_BUCKETS: Int = 769;
27    pub const M_VALUE: Int = 773;
28    pub const ALPHA: Float = 0.6;
29}
30
/// Integer type used for node ids, timestamps and sketch parameters.
pub type Int = u64;
/// Floating-point type used for (decayed) counts and anomaly scores.
pub type Float = f64;
/// Largest finite `Float`; the starting value when folding sketch rows with a minimum.
///
/// Derived from the `Float` alias so it stays well-typed if `Float` ever changes.
const FLOAT_MAX: Float = Float::MAX;
34
35struct Row {
36    a: Int,
37    b: Int,
38    buckets: Vec<Float>,
39}
40
41impl Row {
42    fn new(buckets: Int, rng: &mut Rng) -> Self {
43        Self {
44            a: (rng.rand() % (buckets - 1)) + 1,
45            b: rng.rand() % buckets,
46            buckets: vec![0.; buckets as usize],
47        }
48    }
49
50    fn hash(&self, m_value: Int, source: Int, dest: Int) -> Int {
51        #![allow(unused_comparisons)]
52
53        let resid = m_value
54            .wrapping_mul(dest)
55            .wrapping_add(source)
56            .wrapping_mul(self.a)
57            .wrapping_add(self.b)
58            % self.num_buckets() as Int;
59
60        resid
61            + if resid < 0 {
62                self.num_buckets() as Int
63            } else {
64                0
65            }
66    }
67
68    fn node_insert(&mut self, a: Int, weight: Float) {
69        self.insert(0, a, 0, weight)
70    }
71
72    fn insert(&mut self, m_value: Int, source: Int, dest: Int, weight: Float) {
73        let hash = self.hash(m_value, source, dest) as usize;
74        self.buckets[hash] += weight;
75    }
76
77    fn node_count(&self, source: Int) -> Float {
78        self.count(0, source, 0)
79    }
80
81    fn count(&self, m_value: Int, source: Int, dest: Int) -> Float {
82        self.buckets[self.hash(m_value, source, dest) as usize]
83    }
84
85    fn clear(&mut self) {
86        for bucket in self.buckets.iter_mut() {
87            *bucket = 0.;
88        }
89    }
90
91    fn num_buckets(&self) -> usize {
92        self.buckets.len()
93    }
94
95    fn lower(&mut self, alpha: Float) {
96        for bucket in self.buckets.iter_mut() {
97            *bucket = *bucket * alpha;
98        }
99    }
100}
101
102struct Rng(SmallRng);
103
104impl Rng {
105    fn new(seed: Int) -> Self {
106        use rand::SeedableRng;
107        Self(SmallRng::seed_from_u64(seed as u64))
108    }
109
110    fn rand(&mut self) -> Int {
111        use rand::RngCore;
112        self.0.next_u32() as Int
113    }
114}
115
116struct EdgeHash {
117    m_value: Int,
118    rows: Vec<Row>,
119}
120
121impl EdgeHash {
122    fn new(rows: Int, buckets: Int, m_value: Int, seed: Int) -> Self {
123        let mut rng = Rng::new(seed);
124
125        Self {
126            m_value,
127            rows: (0..rows).map(|_| Row::new(buckets, &mut rng)).collect(),
128        }
129    }
130
131    fn lower(&mut self, alpha: Float) {
132        for row in self.rows.iter_mut() {
133            row.lower(alpha);
134        }
135    }
136
137    fn clear(&mut self) {
138        for row in self.rows.iter_mut() {
139            row.clear();
140        }
141    }
142
143    fn insert(&mut self, source: Int, dest: Int, weight: Float) {
144        for row in self.rows.iter_mut() {
145            row.insert(self.m_value, source, dest, weight);
146        }
147    }
148
149    fn count(&self, source: Int, dest: Int) -> Float {
150        self.rows
151            .iter()
152            .map(|row| row.count(self.m_value, source, dest))
153            .fold(FLOAT_MAX, float_min)
154    }
155}
156
157struct NodeHash {
158    rows: Vec<Row>,
159}
160
161impl NodeHash {
162    fn new(rows: Int, buckets: Int, seed: Int) -> Self {
163        let mut rng = Rng::new(seed);
164
165        Self {
166            rows: (0..rows).map(|_| Row::new(buckets, &mut rng)).collect(),
167        }
168    }
169
170    fn count(&self, source: Int) -> Float {
171        self.rows
172            .iter()
173            .map(|row| row.node_count(source))
174            .fold(FLOAT_MAX, float_min)
175    }
176
177    fn lower(&mut self, alpha: Float) {
178        for row in self.rows.iter_mut() {
179            row.lower(alpha);
180        }
181    }
182
183    fn insert(&mut self, source: Int, weight: Float) {
184        for row in self.rows.iter_mut() {
185            row.node_insert(source, weight);
186        }
187    }
188}
189
190fn float_max(a: Float, b: Float) -> Float {
191    if a >= b {
192        a
193    } else {
194        b
195    }
196}
197
198fn float_min(a: Float, b: Float) -> Float {
199    if a <= b {
200        a
201    } else {
202        b
203    }
204}
205
206fn counts_to_anom(total: Float, current: Float, current_time: Int) -> Float {
207    let current_mean = total / current_time as Float;
208    let sqerr = float_max(0., current - current_mean).powi(2);
209    (sqerr / current_mean) + (sqerr / (current_mean * float_max(1., (current_time - 1) as Float)))
210}
211
212pub struct MidasRParams {
213    /// Number of rows of buckets to use for internal Count-Min Sketches
214    pub rows: Int,
215    /// Number of buckets in each rows to use for internal Count-Min Sketches
216    pub buckets: Int,
217    /// Value used internally in determining bucket placement. Might be
218    /// made private in future version.
219    pub m_value: Int,
220    /// Factor used to to decay current values when our inputs signal
221    /// that time has ticked ahead.
222    pub alpha: Float,
223}
224
225impl Default for MidasRParams {
226    fn default() -> Self {
227        Self {
228            rows: default::NUM_ROWS,
229            buckets: default::NUM_BUCKETS,
230            m_value: default::M_VALUE,
231            alpha: default::ALPHA,
232        }
233    }
234}
235
236pub struct MidasR {
237    current_time: Int,
238    alpha: Float,
239
240    current_count: EdgeHash,
241    total_count: EdgeHash,
242
243    source_score: NodeHash,
244    dest_score: NodeHash,
245    source_total: NodeHash,
246    dest_total: NodeHash,
247}
248
249impl MidasR {
250    pub fn new(
251        MidasRParams {
252            rows,
253            buckets,
254            m_value,
255            alpha,
256        }: MidasRParams,
257    ) -> Self {
258        let dumb_seed = 538;
259
260        Self {
261            current_time: 0,
262            alpha,
263
264            current_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 1),
265            total_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 2),
266
267            source_score: NodeHash::new(rows, buckets, dumb_seed + 3),
268            dest_score: NodeHash::new(rows, buckets, dumb_seed + 4),
269            source_total: NodeHash::new(rows, buckets, dumb_seed + 5),
270            dest_total: NodeHash::new(rows, buckets, dumb_seed + 6),
271        }
272    }
273
274    pub fn current_time(&self) -> Int {
275        self.current_time
276    }
277
278    /// Factor used to to decay current values when our inputs signal
279    /// that time has ticked ahead.
280    pub fn alpha(&self) -> Float {
281        self.alpha
282    }
283
284    /// # Panics
285    ///
286    /// If `time < self.current_time()`
287    pub fn insert(&mut self, (source, dest, time): (Int, Int, Int)) -> Float {
288        assert!(self.current_time <= time);
289
290        if time > self.current_time {
291            // This deviation from the original C++ implementation is
292            // mentioned at
293            // https://github.com/bhatiasiddharth/MIDAS/issues/7#issuecomment-597185695
294            let time_delta = time - self.current_time;
295            let total_decay = self.alpha.powi(time_delta as _);
296            self.current_count.lower(total_decay);
297            self.source_score.lower(total_decay);
298            self.dest_score.lower(total_decay);
299
300            self.current_time = time;
301        }
302
303        self.current_count.insert(source, dest, 1.);
304        self.total_count.insert(source, dest, 1.);
305
306        self.source_score.insert(source, 1.);
307        self.dest_score.insert(dest, 1.);
308        self.source_total.insert(source, 1.);
309        self.dest_total.insert(dest, 1.);
310
311        self.query(source, dest)
312    }
313
314    pub fn query(&self, source: Int, dest: Int) -> Float {
315        let current_score = counts_to_anom(
316            self.total_count.count(source, dest),
317            self.current_count.count(source, dest),
318            self.current_time,
319        );
320        let current_score_source = counts_to_anom(
321            self.source_total.count(source),
322            self.source_score.count(source),
323            self.current_time,
324        );
325        let current_score_dest = counts_to_anom(
326            self.dest_total.count(dest),
327            self.dest_score.count(dest),
328            self.current_time,
329        );
330
331        float_max(
332            float_max(current_score_source, current_score_dest),
333            current_score,
334        )
335        .ln_1p()
336    }
337
338    /// Takes an iterator of `(source, dest, time)` thruples and returns
339    /// an iterator of corresponding scores.
340    ///
341    /// For a more ergonomic version, see `MidasIterator::midas_r`.
342    ///
343    /// # Panics
344    ///
345    /// Subsequent iterator will panic if ever passed a thruple where
346    /// the third element (the time) decreases from its predecessor.
347    pub fn iterate(
348        data: impl Iterator<Item = (Int, Int, Int)>,
349        params: MidasRParams,
350    ) -> impl Iterator<Item = Float> {
351        let mut midas = Self::new(params);
352
353        data.map(move |datum| midas.insert(datum))
354    }
355}
356
357pub struct MidasParams {
358    /// Number of rows of buckets to use for internal Count-Min Sketches
359    pub rows: Int,
360    /// Number of buckets in each rows to use for internal Count-Min Sketches
361    pub buckets: Int,
362    /// Value used internally in determining bucket placement. Might be
363    /// made private in future version.
364    pub m_value: Int,
365}
366
367impl Default for MidasParams {
368    fn default() -> Self {
369        Self {
370            rows: default::NUM_ROWS,
371            buckets: default::NUM_BUCKETS,
372            m_value: default::M_VALUE,
373        }
374    }
375}
376
377pub struct Midas {
378    current_time: Int,
379    current_count: EdgeHash,
380    total_count: EdgeHash,
381}
382
383impl Midas {
384    pub fn new(
385        MidasParams {
386            rows,
387            buckets,
388            m_value,
389        }: MidasParams,
390    ) -> Self {
391        let dumb_seed = 39;
392
393        Self {
394            current_time: 0,
395            current_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 1),
396            total_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 2),
397        }
398    }
399
400    pub fn current_time(&self) -> Int {
401        self.current_time
402    }
403
404    /// # Panics
405    ///
406    /// If `time < self.current_time()`
407    pub fn insert(&mut self, (source, dest, time): (Int, Int, Int)) -> Float {
408        assert!(self.current_time <= time);
409
410        if time > self.current_time {
411            self.current_count.clear();
412            self.current_time = time;
413        }
414
415        self.current_count.insert(source, dest, 1.);
416        self.total_count.insert(source, dest, 1.);
417
418        self.query(source, dest)
419    }
420
421    pub fn query(&self, source: Int, dest: Int) -> Float {
422        let current_mean = self.total_count.count(source, dest) / self.current_time as Float;
423        let sqerr = (self.current_count.count(source, dest) - current_mean).powi(2);
424
425        if self.current_time == 1 {
426            0.
427        } else {
428            (sqerr / current_mean) + (sqerr / (current_mean * (self.current_time - 1) as Float))
429        }
430    }
431
432    /// Takes an iterator of `(source, dest, time)` thruples and returns
433    /// an iterator of corresponding scores.
434    ///
435    /// For a more ergonomic version, see `MidasIterator::midas`.
436    ///
437    /// # Panics
438    ///
439    /// Subsequent iterator will panic if ever passed a thruple where
440    /// the third element (the time) decreases from its predecessor.
441    pub fn iterate(
442        data: impl Iterator<Item = (Int, Int, Int)>,
443        params: MidasParams,
444    ) -> impl Iterator<Item = Float> {
445        let mut midas = Self::new(params);
446
447        data.map(move |datum| midas.insert(datum))
448    }
449}
450
451pub trait MidasIterator<'a>: 'a + Sized + Iterator<Item = (Int, Int, Int)> {
452    /// Takes an iterator of `(source, dest, time)` thruples and returns
453    /// an iterator of corresponding scores.
454    ///
455    /// For a less ergonomic version, see `Midas::iterate`.
456    ///
457    /// # Panics
458    ///
459    /// Subsequent iterator will panic if ever passed a thruple where
460    /// the third element (the time) decreases from its predecessor.
461    fn midas(self, params: MidasParams) -> Box<dyn 'a + Iterator<Item = Float>> {
462        Box::new(Midas::iterate(self, params))
463    }
464
465    fn thing() {
466        let iter = vec![(1, 1, 1), (1, 2, 1), (1, 1, 3), (1, 2, 4)]
467            .into_iter()
468            .midas_r(Default::default());
469
470        for value in iter {
471            println!("{:.6}", value);
472        }
473    }
474
475    /// Takes an iterator of `(source, dest, time)` thruples and returns
476    /// an iterator of corresponding scores.
477    ///
478    /// For a less ergonomic version, see `MidasR::iterate`.
479    ///
480    /// ```rust
481    /// # fn main() {
482    /// use midas_rs::MidasIterator;
483    ///
484    /// let iter = vec![(1, 1, 1), (1, 2, 1), (1, 1, 3), (1, 2, 4)]
485    ///     .into_iter()
486    ///     .midas_r(Default::default());
487    ///
488    /// for value in iter {
489    ///     println!("{:.6}", value);
490    /// }
491    /// # }
492    /// ```
493    ///
494    /// # Panics
495    ///
496    /// Subsequent iterator will panic if ever passed a thruple where
497    /// the third element (the time) decreases from its predecessor.
498    fn midas_r(self, params: MidasRParams) -> Box<dyn 'a + Iterator<Item = Float>> {
499        Box::new(MidasR::iterate(self, params))
500    }
501}
502
503impl<'a, T> MidasIterator<'a> for T where T: 'a + Iterator<Item = (Int, Int, Int)> + Sized {}