js_randomness_predictor/
node_predictor.rs

1use crate::{NodeJsMajorVersion, Predictor, errors::*};
2use std::{
3  error::Error,
4  sync::{Arc, Mutex},
5};
6use z3::{Config, Context, SatResult, Solver, ast::*};
7
8pub struct NodePredictor {
9  sequence: Vec<f64>,
10  internal_sequence: Vec<f64>,
11  is_solved: bool,
12  node_js_major_version: NodeJsMajorVersion,
13  conc_state_0: u64,
14  conc_state_1: u64,
15  num_predictions_made: Arc<Mutex<u8>>,
16}
17
18impl Predictor for NodePredictor {
19  fn predict_next(&mut self) -> Result<f64, Box<dyn Error>> {
20    self.increment_prediction_count()?;
21    self.solve_symbolic_state()?;
22    let v = self.xor_shift_128_plus_concrete();
23    let p = self.to_double(v);
24    return Ok(p);
25  }
26}
27
28impl NodePredictor {
29  pub const MAX_NUM_PREDICTIONS: u8 = 64;
30  const SS_0_STR: &str = "sym_state_0";
31  const SS_1_STR: &str = "sym_state_1";
32
33  pub fn new(node_js_major_version: NodeJsMajorVersion, seq: Vec<f64>) -> Self {
34    let len = seq.len() as u8;
35    let mut iseq = seq.clone();
36    iseq.reverse();
37
38    return NodePredictor {
39      internal_sequence: iseq,
40      sequence: seq,
41      node_js_major_version,
42      conc_state_0: 0,
43      conc_state_1: 0,
44      num_predictions_made: Arc::new(Mutex::new(len)),
45      is_solved: false,
46    };
47  }
48
49  #[allow(dead_code)]
50  pub fn sequence(&self) -> &[f64] {
51    return &self.sequence;
52  }
53
54  // So consumers don't have to import the Predictor trait as well as the struct.
55  pub fn predict_next(&mut self) -> Result<f64, Box<dyn Error>> {
56    return <Self as Predictor>::predict_next(self);
57  }
58
59  fn xor_shift_128_plus_concrete(&mut self) -> u64 {
60    let result = self.conc_state_0;
61    let t1 = self.conc_state_0;
62    let mut t0 = self.conc_state_1 ^ (self.conc_state_0 >> 26);
63    t0 ^= self.conc_state_0;
64    t0 ^= (t0 >> 17) ^ (t0 >> 34) ^ (t0 >> 51);
65    t0 ^= (t0 << 23) ^ (t0 << 46);
66    self.conc_state_0 = t0;
67    self.conc_state_1 = t1;
68    return result;
69  }
70
71  fn to_double(&self, value: u64) -> f64 {
72    if self.node_js_major_version as u8 >= 24 {
73      return (value >> 11) as f64 / (1u64 << 53) as f64;
74    }
75    return f64::from_bits((value >> 12) | 0x3FF0000000000000) - 1.0;
76  }
77
78  // If our count is below the max, we can increment, otherwise error.
79  fn increment_prediction_count(&self) -> Result<(), PredictionLimitError> {
80    let mut c = self.num_predictions_made.lock()?;
81    if *c >= Self::MAX_NUM_PREDICTIONS {
82      return Err(PredictionLimitError);
83    }
84    *c += 1;
85    return Ok(());
86  }
87
88  #[allow(dead_code)]
89  fn reset(&mut self, new_sequence: Vec<f64>) -> Result<(), PredictionLimitError> {
90    let mut c = self.num_predictions_made.lock()?;
91    if *c < Self::MAX_NUM_PREDICTIONS {
92      return Ok(());
93    }
94    *c = new_sequence.len() as u8;
95    self.is_solved = false;
96    self.internal_sequence = new_sequence.to_vec();
97    self.sequence = new_sequence.to_vec();
98    self.internal_sequence.reverse();
99    return Ok(());
100  }
101
102  fn solve_symbolic_state(&mut self) -> Result<(), InitError> {
103    if self.is_solved {
104      return Ok(());
105    }
106
107    let config = Config::new();
108    let context = Context::new(&config);
109    let solver = Solver::new(&context);
110
111    let mut sym_state_0 = BV::new_const(&context, Self::SS_0_STR, 64);
112    let mut sym_state_1 = BV::new_const(&context, Self::SS_1_STR, 64);
113
114    for &observed in &self.internal_sequence {
115      Self::xor_shift_128_plus_symbolic(&context, &mut sym_state_0, &mut sym_state_1);
116      Self::constrain_mantissa(
117        observed,
118        self.node_js_major_version,
119        &context,
120        &solver,
121        &sym_state_0,
122      );
123    }
124
125    if solver.check() != SatResult::Sat {
126      return Err(InitError::Unsat);
127    }
128
129    let model = solver.get_model().ok_or(InitError::MissingModel)?;
130
131    self.conc_state_0 = model
132      .eval(&sym_state_0, true)
133      .ok_or(InitError::EvalFailed(Self::SS_0_STR))?
134      .as_u64()
135      .ok_or(InitError::ConvertFailed(Self::SS_0_STR))?;
136
137    self.conc_state_1 = model
138      .eval(&sym_state_1, true)
139      .ok_or(InitError::EvalFailed(Self::SS_1_STR))?
140      .as_u64()
141      .ok_or(InitError::ConvertFailed(Self::SS_1_STR))?;
142
143    for _ in 0..self.internal_sequence.len() {
144      self.xor_shift_128_plus_concrete();
145    }
146
147    self.is_solved = true;
148    return Ok(());
149  }
150
151  // Static 'helper' method
152  fn xor_shift_128_plus_symbolic<'a>(
153    context: &'a Context,
154    state_0: &mut BV<'a>,
155    state_1: &mut BV<'a>,
156  ) {
157    let state_0_shifted_left = state_0.bvshl(&BV::from_u64(context, 23, 64));
158    let mut s1 = &*state_0 ^ state_0_shifted_left;
159    let s1_shifted_right = s1.bvlshr(&BV::from_u64(context, 17, 64));
160
161    s1 ^= s1_shifted_right;
162    s1 ^= state_1.clone();
163    s1 ^= state_1.bvlshr(&BV::from_u64(context, 26, 64));
164
165    std::mem::swap(state_0, state_1);
166    *state_1 = s1;
167  }
168
169  // Static 'helper' method
170  fn constrain_mantissa(
171    value: f64,
172    nodejs_version: NodeJsMajorVersion,
173    context: &Context,
174    solver: &Solver,
175    state_0: &BV,
176  ) {
177    if nodejs_version as u8 >= 24 {
178      // Recover mantissa
179      let mantissa = (value * (1u64 << 53) as f64) as u64;
180      // Add mantissa constraint
181      solver.assert(
182        &state_0
183          .bvlshr(&BV::from_u64(context, 11, 64))
184          ._eq(&BV::from_u64(context, mantissa, 64)),
185      );
186    } else {
187      // Recover mantissa
188      let mantissa = f64::to_bits(value + 1.0) & ((1u64 << 52) - 1);
189      // Add mantissa constraint
190      solver.assert(
191        &BV::from_u64(context, mantissa, 64)._eq(&state_0.bvlshr(&BV::from_u64(context, 12, 64))),
192      );
193    }
194  }
195}
196
197#[cfg(test)]
198mod tests {
199  mod general {
200    use crate::{NodePredictor, errors::PredictionLimitError};
201    use std::error::Error;
202
203    #[test]
204    fn reset_after_exhaustion() -> Result<(), Box<dyn Error>> {
205      let seq_first = vec![
206        0.777225464783239,
207        0.15637962909874392,
208        0.61479550021439,
209        0.613383431187081,
210      ];
211
212      let exp_first = vec![
213        0.13780690875659396,
214        0.9982326337150321,
215        0.004547103255256535,
216        0.14287124304719512,
217        0.07193734860746803,
218        0.41988043371402806,
219        0.2197922772380051,
220        0.3919840116873258,
221        0.872346223942074,
222        0.8706850288116219,
223        0.15113105207209843,
224        0.6388396452515654,
225        0.49440586365264294,
226        0.6587982725994921,
227        0.18400263468494316,
228        0.662415645160952,
229        0.004233542647695265,
230        0.7850940676778024,
231        0.8718140231245509,
232        0.6789540919039344,
233        0.3903186400622056,
234        0.5518644169835116,
235        0.5827729085540138,
236        0.5554012760270357,
237        0.5233538890694638,
238        0.9581085436854987,
239        0.49105573307668293,
240        0.4887541485622109,
241        0.03580260719438155,
242        0.7486864084447863,
243        0.9442814920321353,
244        0.279500250517147,
245        0.573892252919875,
246        0.35303563579361574,
247        0.49663075416404756,
248        0.3761838996110659,
249        0.01940835807427621,
250        0.048560429750311496,
251        0.12478054659752413,
252        0.8748800514290499,
253        0.5585005650941148,
254        0.861530489078495,
255        0.5288744964943755,
256        0.6986980332092166,
257        0.25771635223672984,
258        0.9727178859177362,
259        0.6867934573316927,
260        0.6970474592601525,
261        0.8035245910646631,
262        0.34589316291057026,
263        0.16026446047340037,
264        0.1871389590142859,
265        0.5065543089345518,
266        0.13565177330674527,
267        0.8171462352178724,
268        0.9132684591493374,
269        0.3537461024035218,
270        0.10449476983306794,
271        0.8400598276661568,
272        0.6256282841337143,
273        0.19469967920827957,
274      ];
275
276      let seq_second = vec![
277        0.1155167115902066,
278        0.2738831377473743,
279        0.475867049008157,
280        0.24131310081058077,
281      ];
282
283      let exp_second = vec![
284        0.5567280997370845,
285        0.09262950949369997,
286        0.9774839147267224,
287        0.07372009723227202,
288        0.8903569034540151,
289        0.2559913027687497,
290        0.9357996349973149,
291        0.10659667352144908,
292        0.34537275726933636,
293        0.23697119929732424,
294        0.1411756579261214,
295        0.4397982843668222,
296        0.9628074927171562,
297        0.15509374502364615,
298      ];
299
300      let seq_first_len = seq_first.len();
301
302      let mut np = NodePredictor::new(crate::NodeJsMajorVersion::V24, seq_first);
303
304      let mut first_predictions = vec![];
305
306      for _ in 0..exp_first.len() {
307        match np.predict_next() {
308          Ok(prediction) => {
309            first_predictions.push(prediction);
310          }
311          Err(e) => {
312            if let Some(_pred_limit_err) = e.downcast_ref::<PredictionLimitError>() {
313              np.reset(seq_second)?;
314              break;
315            } else {
316              return Err(e);
317            }
318          }
319        }
320      }
321
322      assert_eq!(first_predictions.len() + seq_first_len, 64);
323
324      let mut second_preds = vec![];
325      for _ in 0..exp_second.len() {
326        let pred = np.predict_next()?;
327        second_preds.push(pred);
328      }
329
330      assert_eq!(second_preds, exp_second);
331      return Ok(());
332    }
333  }
334
335  mod node_v22 {
336    use crate::NodePredictor;
337    use std::error::Error;
338
339    #[test]
340    fn correctly_predicts_sequence() -> Result<(), Box<dyn Error>> {
341      let node_v22_seq = vec![
342        0.36280726230126614,
343        0.32726837947512855,
344        0.22834780314989023,
345        0.18295517908119385,
346      ];
347      let node_v22_expected = vec![
348        0.8853110028441145,
349        0.14326940888839124,
350        0.035607792006009165,
351        0.6491231376351401,
352        0.3345277284146617,
353        0.42618019812863417,
354      ];
355
356      let mut v8p_node_v22 = NodePredictor::new(crate::NodeJsMajorVersion::V22, node_v22_seq);
357
358      let mut v8_node_v22_predictions = vec![];
359      for _ in 0..node_v22_expected.len() {
360        let prediction = v8p_node_v22.predict_next()?;
361        v8_node_v22_predictions.push(prediction);
362      }
363
364      assert_eq!(v8_node_v22_predictions, node_v22_expected);
365      return Ok(());
366    }
367  }
368
369  mod node_v24 {
370    use std::error::Error;
371
372    #[test]
373    fn correctly_predicts_sequence() -> Result<(), Box<dyn Error>> {
374      let node_v24_seq = vec![
375        0.01800425609760259,
376        0.19267361208155598,
377        0.9892770985784053,
378        0.49553307275603264,
379        0.7362624704291061,
380      ];
381      let node_v24_expected = vec![
382        0.8664993194151147,
383        0.5549329443482626,
384        0.8879559862322086,
385        0.9570142746667122,
386        0.7514661363382521,
387        0.9348208735728415,
388      ];
389
390      let mut v8p_node_v24 =
391        crate::NodePredictor::new(crate::NodeJsMajorVersion::V24, node_v24_seq);
392
393      let mut v8_node_v24_predictions = vec![];
394      for _ in 0..node_v24_expected.len() {
395        let prediction = v8p_node_v24.predict_next()?;
396        v8_node_v24_predictions.push(prediction)
397      }
398
399      assert_eq!(v8_node_v24_predictions, node_v24_expected);
400      return Ok(());
401    }
402  }
403}