1use rand::rngs::SmallRng;
21
22pub mod default {
23 use super::{Float, Int};
24
25 pub const NUM_ROWS: Int = 2;
26 pub const NUM_BUCKETS: Int = 769;
27 pub const M_VALUE: Int = 773;
28 pub const ALPHA: Float = 0.6;
29}
30
/// Integer type used for node identifiers, timestamps and hash parameters.
pub type Int = u64;
/// Floating-point type used for counts and scores.
pub type Float = f64;
/// Largest finite `Float`; used as the starting value when taking a minimum over rows.
// The associated constant follows the `Float` alias and replaces the legacy
// `std::f64::MAX` module constant.
const FLOAT_MAX: Float = Float::MAX;
34
35struct Row {
36 a: Int,
37 b: Int,
38 buckets: Vec<Float>,
39}
40
41impl Row {
42 fn new(buckets: Int, rng: &mut Rng) -> Self {
43 Self {
44 a: (rng.rand() % (buckets - 1)) + 1,
45 b: rng.rand() % buckets,
46 buckets: vec![0.; buckets as usize],
47 }
48 }
49
50 fn hash(&self, m_value: Int, source: Int, dest: Int) -> Int {
51 #![allow(unused_comparisons)]
52
53 let resid = m_value
54 .wrapping_mul(dest)
55 .wrapping_add(source)
56 .wrapping_mul(self.a)
57 .wrapping_add(self.b)
58 % self.num_buckets() as Int;
59
60 resid
61 + if resid < 0 {
62 self.num_buckets() as Int
63 } else {
64 0
65 }
66 }
67
68 fn node_insert(&mut self, a: Int, weight: Float) {
69 self.insert(0, a, 0, weight)
70 }
71
72 fn insert(&mut self, m_value: Int, source: Int, dest: Int, weight: Float) {
73 let hash = self.hash(m_value, source, dest) as usize;
74 self.buckets[hash] += weight;
75 }
76
77 fn node_count(&self, source: Int) -> Float {
78 self.count(0, source, 0)
79 }
80
81 fn count(&self, m_value: Int, source: Int, dest: Int) -> Float {
82 self.buckets[self.hash(m_value, source, dest) as usize]
83 }
84
85 fn clear(&mut self) {
86 for bucket in self.buckets.iter_mut() {
87 *bucket = 0.;
88 }
89 }
90
91 fn num_buckets(&self) -> usize {
92 self.buckets.len()
93 }
94
95 fn lower(&mut self, alpha: Float) {
96 for bucket in self.buckets.iter_mut() {
97 *bucket = *bucket * alpha;
98 }
99 }
100}
101
102struct Rng(SmallRng);
103
104impl Rng {
105 fn new(seed: Int) -> Self {
106 use rand::SeedableRng;
107 Self(SmallRng::seed_from_u64(seed as u64))
108 }
109
110 fn rand(&mut self) -> Int {
111 use rand::RngCore;
112 self.0.next_u32() as Int
113 }
114}
115
116struct EdgeHash {
117 m_value: Int,
118 rows: Vec<Row>,
119}
120
121impl EdgeHash {
122 fn new(rows: Int, buckets: Int, m_value: Int, seed: Int) -> Self {
123 let mut rng = Rng::new(seed);
124
125 Self {
126 m_value,
127 rows: (0..rows).map(|_| Row::new(buckets, &mut rng)).collect(),
128 }
129 }
130
131 fn lower(&mut self, alpha: Float) {
132 for row in self.rows.iter_mut() {
133 row.lower(alpha);
134 }
135 }
136
137 fn clear(&mut self) {
138 for row in self.rows.iter_mut() {
139 row.clear();
140 }
141 }
142
143 fn insert(&mut self, source: Int, dest: Int, weight: Float) {
144 for row in self.rows.iter_mut() {
145 row.insert(self.m_value, source, dest, weight);
146 }
147 }
148
149 fn count(&self, source: Int, dest: Int) -> Float {
150 self.rows
151 .iter()
152 .map(|row| row.count(self.m_value, source, dest))
153 .fold(FLOAT_MAX, float_min)
154 }
155}
156
157struct NodeHash {
158 rows: Vec<Row>,
159}
160
161impl NodeHash {
162 fn new(rows: Int, buckets: Int, seed: Int) -> Self {
163 let mut rng = Rng::new(seed);
164
165 Self {
166 rows: (0..rows).map(|_| Row::new(buckets, &mut rng)).collect(),
167 }
168 }
169
170 fn count(&self, source: Int) -> Float {
171 self.rows
172 .iter()
173 .map(|row| row.node_count(source))
174 .fold(FLOAT_MAX, float_min)
175 }
176
177 fn lower(&mut self, alpha: Float) {
178 for row in self.rows.iter_mut() {
179 row.lower(alpha);
180 }
181 }
182
183 fn insert(&mut self, source: Int, weight: Float) {
184 for row in self.rows.iter_mut() {
185 row.node_insert(source, weight);
186 }
187 }
188}
189
190fn float_max(a: Float, b: Float) -> Float {
191 if a >= b {
192 a
193 } else {
194 b
195 }
196}
197
198fn float_min(a: Float, b: Float) -> Float {
199 if a <= b {
200 a
201 } else {
202 b
203 }
204}
205
206fn counts_to_anom(total: Float, current: Float, current_time: Int) -> Float {
207 let current_mean = total / current_time as Float;
208 let sqerr = float_max(0., current - current_mean).powi(2);
209 (sqerr / current_mean) + (sqerr / (current_mean * float_max(1., (current_time - 1) as Float)))
210}
211
212pub struct MidasRParams {
213 pub rows: Int,
215 pub buckets: Int,
217 pub m_value: Int,
220 pub alpha: Float,
223}
224
225impl Default for MidasRParams {
226 fn default() -> Self {
227 Self {
228 rows: default::NUM_ROWS,
229 buckets: default::NUM_BUCKETS,
230 m_value: default::M_VALUE,
231 alpha: default::ALPHA,
232 }
233 }
234}
235
236pub struct MidasR {
237 current_time: Int,
238 alpha: Float,
239
240 current_count: EdgeHash,
241 total_count: EdgeHash,
242
243 source_score: NodeHash,
244 dest_score: NodeHash,
245 source_total: NodeHash,
246 dest_total: NodeHash,
247}
248
249impl MidasR {
250 pub fn new(
251 MidasRParams {
252 rows,
253 buckets,
254 m_value,
255 alpha,
256 }: MidasRParams,
257 ) -> Self {
258 let dumb_seed = 538;
259
260 Self {
261 current_time: 0,
262 alpha,
263
264 current_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 1),
265 total_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 2),
266
267 source_score: NodeHash::new(rows, buckets, dumb_seed + 3),
268 dest_score: NodeHash::new(rows, buckets, dumb_seed + 4),
269 source_total: NodeHash::new(rows, buckets, dumb_seed + 5),
270 dest_total: NodeHash::new(rows, buckets, dumb_seed + 6),
271 }
272 }
273
274 pub fn current_time(&self) -> Int {
275 self.current_time
276 }
277
278 pub fn alpha(&self) -> Float {
281 self.alpha
282 }
283
284 pub fn insert(&mut self, (source, dest, time): (Int, Int, Int)) -> Float {
288 assert!(self.current_time <= time);
289
290 if time > self.current_time {
291 let time_delta = time - self.current_time;
295 let total_decay = self.alpha.powi(time_delta as _);
296 self.current_count.lower(total_decay);
297 self.source_score.lower(total_decay);
298 self.dest_score.lower(total_decay);
299
300 self.current_time = time;
301 }
302
303 self.current_count.insert(source, dest, 1.);
304 self.total_count.insert(source, dest, 1.);
305
306 self.source_score.insert(source, 1.);
307 self.dest_score.insert(dest, 1.);
308 self.source_total.insert(source, 1.);
309 self.dest_total.insert(dest, 1.);
310
311 self.query(source, dest)
312 }
313
314 pub fn query(&self, source: Int, dest: Int) -> Float {
315 let current_score = counts_to_anom(
316 self.total_count.count(source, dest),
317 self.current_count.count(source, dest),
318 self.current_time,
319 );
320 let current_score_source = counts_to_anom(
321 self.source_total.count(source),
322 self.source_score.count(source),
323 self.current_time,
324 );
325 let current_score_dest = counts_to_anom(
326 self.dest_total.count(dest),
327 self.dest_score.count(dest),
328 self.current_time,
329 );
330
331 float_max(
332 float_max(current_score_source, current_score_dest),
333 current_score,
334 )
335 .ln_1p()
336 }
337
338 pub fn iterate(
348 data: impl Iterator<Item = (Int, Int, Int)>,
349 params: MidasRParams,
350 ) -> impl Iterator<Item = Float> {
351 let mut midas = Self::new(params);
352
353 data.map(move |datum| midas.insert(datum))
354 }
355}
356
357pub struct MidasParams {
358 pub rows: Int,
360 pub buckets: Int,
362 pub m_value: Int,
365}
366
367impl Default for MidasParams {
368 fn default() -> Self {
369 Self {
370 rows: default::NUM_ROWS,
371 buckets: default::NUM_BUCKETS,
372 m_value: default::M_VALUE,
373 }
374 }
375}
376
377pub struct Midas {
378 current_time: Int,
379 current_count: EdgeHash,
380 total_count: EdgeHash,
381}
382
383impl Midas {
384 pub fn new(
385 MidasParams {
386 rows,
387 buckets,
388 m_value,
389 }: MidasParams,
390 ) -> Self {
391 let dumb_seed = 39;
392
393 Self {
394 current_time: 0,
395 current_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 1),
396 total_count: EdgeHash::new(rows, buckets, m_value, dumb_seed + 2),
397 }
398 }
399
400 pub fn current_time(&self) -> Int {
401 self.current_time
402 }
403
404 pub fn insert(&mut self, (source, dest, time): (Int, Int, Int)) -> Float {
408 assert!(self.current_time <= time);
409
410 if time > self.current_time {
411 self.current_count.clear();
412 self.current_time = time;
413 }
414
415 self.current_count.insert(source, dest, 1.);
416 self.total_count.insert(source, dest, 1.);
417
418 self.query(source, dest)
419 }
420
421 pub fn query(&self, source: Int, dest: Int) -> Float {
422 let current_mean = self.total_count.count(source, dest) / self.current_time as Float;
423 let sqerr = (self.current_count.count(source, dest) - current_mean).powi(2);
424
425 if self.current_time == 1 {
426 0.
427 } else {
428 (sqerr / current_mean) + (sqerr / (current_mean * (self.current_time - 1) as Float))
429 }
430 }
431
432 pub fn iterate(
442 data: impl Iterator<Item = (Int, Int, Int)>,
443 params: MidasParams,
444 ) -> impl Iterator<Item = Float> {
445 let mut midas = Self::new(params);
446
447 data.map(move |datum| midas.insert(datum))
448 }
449}
450
451pub trait MidasIterator<'a>: 'a + Sized + Iterator<Item = (Int, Int, Int)> {
452 fn midas(self, params: MidasParams) -> Box<dyn 'a + Iterator<Item = Float>> {
462 Box::new(Midas::iterate(self, params))
463 }
464
465 fn thing() {
466 let iter = vec![(1, 1, 1), (1, 2, 1), (1, 1, 3), (1, 2, 4)]
467 .into_iter()
468 .midas_r(Default::default());
469
470 for value in iter {
471 println!("{:.6}", value);
472 }
473 }
474
475 fn midas_r(self, params: MidasRParams) -> Box<dyn 'a + Iterator<Item = Float>> {
499 Box::new(MidasR::iterate(self, params))
500 }
501}
502
503impl<'a, T> MidasIterator<'a> for T where T: 'a + Iterator<Item = (Int, Int, Int)> + Sized {}