1use crate::{NodeJsMajorVersion, Predictor, errors::*};
2use std::{
3 error::Error,
4 sync::{Arc, Mutex},
5};
6use z3::{Config, Context, SatResult, Solver, ast::*};
7
8pub struct NodePredictor {
9 sequence: Vec<f64>,
10 internal_sequence: Vec<f64>,
11 is_solved: bool,
12 node_js_major_version: NodeJsMajorVersion,
13 conc_state_0: u64,
14 conc_state_1: u64,
15 num_predictions_made: Arc<Mutex<u8>>,
16}
17
18impl Predictor for NodePredictor {
19 fn predict_next(&mut self) -> Result<f64, Box<dyn Error>> {
20 self.increment_prediction_count()?;
21 self.solve_symbolic_state()?;
22 let v = self.xor_shift_128_plus_concrete();
23 let p = self.to_double(v);
24 return Ok(p);
25 }
26}
27
28impl NodePredictor {
29 pub const MAX_NUM_PREDICTIONS: u8 = 64;
30 const SS_0_STR: &str = "sym_state_0";
31 const SS_1_STR: &str = "sym_state_1";
32
33 pub fn new(node_js_major_version: NodeJsMajorVersion, seq: Vec<f64>) -> Self {
34 let len = seq.len() as u8;
35 let mut iseq = seq.clone();
36 iseq.reverse();
37
38 return NodePredictor {
39 internal_sequence: iseq,
40 sequence: seq,
41 node_js_major_version,
42 conc_state_0: 0,
43 conc_state_1: 0,
44 num_predictions_made: Arc::new(Mutex::new(len)),
45 is_solved: false,
46 };
47 }
48
49 #[allow(dead_code)]
50 pub fn sequence(&self) -> &[f64] {
51 return &self.sequence;
52 }
53
54 pub fn predict_next(&mut self) -> Result<f64, Box<dyn Error>> {
56 return <Self as Predictor>::predict_next(self);
57 }
58
59 fn xor_shift_128_plus_concrete(&mut self) -> u64 {
60 let result = self.conc_state_0;
61 let t1 = self.conc_state_0;
62 let mut t0 = self.conc_state_1 ^ (self.conc_state_0 >> 26);
63 t0 ^= self.conc_state_0;
64 t0 ^= (t0 >> 17) ^ (t0 >> 34) ^ (t0 >> 51);
65 t0 ^= (t0 << 23) ^ (t0 << 46);
66 self.conc_state_0 = t0;
67 self.conc_state_1 = t1;
68 return result;
69 }
70
71 fn to_double(&self, value: u64) -> f64 {
72 if self.node_js_major_version as u8 >= 24 {
73 return (value >> 11) as f64 / (1u64 << 53) as f64;
74 }
75 return f64::from_bits((value >> 12) | 0x3FF0000000000000) - 1.0;
76 }
77
78 fn increment_prediction_count(&self) -> Result<(), PredictionLimitError> {
80 let mut c = self.num_predictions_made.lock()?;
81 if *c >= Self::MAX_NUM_PREDICTIONS {
82 return Err(PredictionLimitError);
83 }
84 *c += 1;
85 return Ok(());
86 }
87
88 #[allow(dead_code)]
89 fn reset(&mut self, new_sequence: Vec<f64>) -> Result<(), PredictionLimitError> {
90 let mut c = self.num_predictions_made.lock()?;
91 if *c < Self::MAX_NUM_PREDICTIONS {
92 return Ok(());
93 }
94 *c = new_sequence.len() as u8;
95 self.is_solved = false;
96 self.internal_sequence = new_sequence.to_vec();
97 self.sequence = new_sequence.to_vec();
98 self.internal_sequence.reverse();
99 return Ok(());
100 }
101
102 fn solve_symbolic_state(&mut self) -> Result<(), InitError> {
103 if self.is_solved {
104 return Ok(());
105 }
106
107 let config = Config::new();
108 let context = Context::new(&config);
109 let solver = Solver::new(&context);
110
111 let mut sym_state_0 = BV::new_const(&context, Self::SS_0_STR, 64);
112 let mut sym_state_1 = BV::new_const(&context, Self::SS_1_STR, 64);
113
114 for &observed in &self.internal_sequence {
115 Self::xor_shift_128_plus_symbolic(&context, &mut sym_state_0, &mut sym_state_1);
116 Self::constrain_mantissa(
117 observed,
118 self.node_js_major_version,
119 &context,
120 &solver,
121 &sym_state_0,
122 );
123 }
124
125 if solver.check() != SatResult::Sat {
126 return Err(InitError::Unsat);
127 }
128
129 let model = solver.get_model().ok_or(InitError::MissingModel)?;
130
131 self.conc_state_0 = model
132 .eval(&sym_state_0, true)
133 .ok_or(InitError::EvalFailed(Self::SS_0_STR))?
134 .as_u64()
135 .ok_or(InitError::ConvertFailed(Self::SS_0_STR))?;
136
137 self.conc_state_1 = model
138 .eval(&sym_state_1, true)
139 .ok_or(InitError::EvalFailed(Self::SS_1_STR))?
140 .as_u64()
141 .ok_or(InitError::ConvertFailed(Self::SS_1_STR))?;
142
143 for _ in 0..self.internal_sequence.len() {
144 self.xor_shift_128_plus_concrete();
145 }
146
147 self.is_solved = true;
148 return Ok(());
149 }
150
151 fn xor_shift_128_plus_symbolic<'a>(
153 context: &'a Context,
154 state_0: &mut BV<'a>,
155 state_1: &mut BV<'a>,
156 ) {
157 let state_0_shifted_left = state_0.bvshl(&BV::from_u64(context, 23, 64));
158 let mut s1 = &*state_0 ^ state_0_shifted_left;
159 let s1_shifted_right = s1.bvlshr(&BV::from_u64(context, 17, 64));
160
161 s1 ^= s1_shifted_right;
162 s1 ^= state_1.clone();
163 s1 ^= state_1.bvlshr(&BV::from_u64(context, 26, 64));
164
165 std::mem::swap(state_0, state_1);
166 *state_1 = s1;
167 }
168
169 fn constrain_mantissa(
171 value: f64,
172 nodejs_version: NodeJsMajorVersion,
173 context: &Context,
174 solver: &Solver,
175 state_0: &BV,
176 ) {
177 if nodejs_version as u8 >= 24 {
178 let mantissa = (value * (1u64 << 53) as f64) as u64;
180 solver.assert(
182 &state_0
183 .bvlshr(&BV::from_u64(context, 11, 64))
184 ._eq(&BV::from_u64(context, mantissa, 64)),
185 );
186 } else {
187 let mantissa = f64::to_bits(value + 1.0) & ((1u64 << 52) - 1);
189 solver.assert(
191 &BV::from_u64(context, mantissa, 64)._eq(&state_0.bvlshr(&BV::from_u64(context, 12, 64))),
192 );
193 }
194 }
195}
196
#[cfg(test)]
mod tests {
    mod general {
        use crate::{NodePredictor, errors::PredictionLimitError};
        use std::error::Error;

        /// Exhausts the prediction budget on one observed sequence, then resets with a
        /// second sequence. Checks the predictions on both sides of the reset.
        #[test]
        fn reset_after_exhaustion() -> Result<(), Box<dyn Error>> {
            let seq_first = vec![
                0.777225464783239,
                0.15637962909874392,
                0.61479550021439,
                0.613383431187081,
            ];

            let exp_first = vec![
                0.13780690875659396,
                0.9982326337150321,
                0.004547103255256535,
                0.14287124304719512,
                0.07193734860746803,
                0.41988043371402806,
                0.2197922772380051,
                0.3919840116873258,
                0.872346223942074,
                0.8706850288116219,
                0.15113105207209843,
                0.6388396452515654,
                0.49440586365264294,
                0.6587982725994921,
                0.18400263468494316,
                0.662415645160952,
                0.004233542647695265,
                0.7850940676778024,
                0.8718140231245509,
                0.6789540919039344,
                0.3903186400622056,
                0.5518644169835116,
                0.5827729085540138,
                0.5554012760270357,
                0.5233538890694638,
                0.9581085436854987,
                0.49105573307668293,
                0.4887541485622109,
                0.03580260719438155,
                0.7486864084447863,
                0.9442814920321353,
                0.279500250517147,
                0.573892252919875,
                0.35303563579361574,
                0.49663075416404756,
                0.3761838996110659,
                0.01940835807427621,
                0.048560429750311496,
                0.12478054659752413,
                0.8748800514290499,
                0.5585005650941148,
                0.861530489078495,
                0.5288744964943755,
                0.6986980332092166,
                0.25771635223672984,
                0.9727178859177362,
                0.6867934573316927,
                0.6970474592601525,
                0.8035245910646631,
                0.34589316291057026,
                0.16026446047340037,
                0.1871389590142859,
                0.5065543089345518,
                0.13565177330674527,
                0.8171462352178724,
                0.9132684591493374,
                0.3537461024035218,
                0.10449476983306794,
                0.8400598276661568,
                0.6256282841337143,
                0.19469967920827957,
            ];

            let seq_second = vec![
                0.1155167115902066,
                0.2738831377473743,
                0.475867049008157,
                0.24131310081058077,
            ];

            let exp_second = vec![
                0.5567280997370845,
                0.09262950949369997,
                0.9774839147267224,
                0.07372009723227202,
                0.8903569034540151,
                0.2559913027687497,
                0.9357996349973149,
                0.10659667352144908,
                0.34537275726933636,
                0.23697119929732424,
                0.1411756579261214,
                0.4397982843668222,
                0.9628074927171562,
                0.15509374502364615,
            ];

            let seq_first_len = seq_first.len();

            let mut np = NodePredictor::new(crate::NodeJsMajorVersion::V24, seq_first);

            // Predict until the budget runs out; the limit error triggers the reset.
            let mut first_predictions = Vec::new();
            for _ in 0..exp_first.len() {
                match np.predict_next() {
                    Ok(prediction) => first_predictions.push(prediction),
                    Err(e) if e.is::<PredictionLimitError>() => {
                        np.reset(seq_second)?;
                        break;
                    }
                    Err(e) => return Err(e),
                }
            }

            // Observations plus predictions must fill the budget exactly.
            let budget = usize::from(NodePredictor::MAX_NUM_PREDICTIONS);
            assert_eq!(first_predictions.len() + seq_first_len, budget);
            // Check the values produced before the reset, not just how many there were.
            assert_eq!(first_predictions[..], exp_first[..first_predictions.len()]);

            let mut second_preds = Vec::new();
            for _ in 0..exp_second.len() {
                second_preds.push(np.predict_next()?);
            }

            assert_eq!(second_preds, exp_second);
            return Ok(());
        }
    }

    mod node_v22 {
        use crate::NodePredictor;
        use std::error::Error;

        /// Predictions under the pre-v24 (52-bit mantissa) conversion.
        #[test]
        fn correctly_predicts_sequence() -> Result<(), Box<dyn Error>> {
            let node_v22_seq = vec![
                0.36280726230126614,
                0.32726837947512855,
                0.22834780314989023,
                0.18295517908119385,
            ];
            let node_v22_expected = vec![
                0.8853110028441145,
                0.14326940888839124,
                0.035607792006009165,
                0.6491231376351401,
                0.3345277284146617,
                0.42618019812863417,
            ];

            let mut v8p_node_v22 = NodePredictor::new(crate::NodeJsMajorVersion::V22, node_v22_seq);

            let mut v8_node_v22_predictions = Vec::new();
            for _ in 0..node_v22_expected.len() {
                v8_node_v22_predictions.push(v8p_node_v22.predict_next()?);
            }

            assert_eq!(v8_node_v22_predictions, node_v22_expected);
            return Ok(());
        }
    }

    mod node_v24 {
        use std::error::Error;

        /// Predictions under the v24 (53-bit division) conversion.
        #[test]
        fn correctly_predicts_sequence() -> Result<(), Box<dyn Error>> {
            let node_v24_seq = vec![
                0.01800425609760259,
                0.19267361208155598,
                0.9892770985784053,
                0.49553307275603264,
                0.7362624704291061,
            ];
            let node_v24_expected = vec![
                0.8664993194151147,
                0.5549329443482626,
                0.8879559862322086,
                0.9570142746667122,
                0.7514661363382521,
                0.9348208735728415,
            ];

            let mut v8p_node_v24 =
                crate::NodePredictor::new(crate::NodeJsMajorVersion::V24, node_v24_seq);

            let mut v8_node_v24_predictions = Vec::new();
            for _ in 0..node_v24_expected.len() {
                v8_node_v24_predictions.push(v8p_node_v24.predict_next()?);
            }

            assert_eq!(v8_node_v24_predictions, node_v24_expected);
            return Ok(());
        }
    }
}