Skip to main content

rival/eval/
machine.rs

1//! Register machine evaluator.
2
3use std::collections::HashMap;
4
5use super::{
6    ast::Expr,
7    execute,
8    instructions::{Instruction, InstructionData},
9};
10use crate::{
11    eval::{
12        ops,
13        profile::{Execution, Profiler},
14    },
15    interval::Ival,
16};
17use indexmap::IndexMap;
18use rug::Float;
19
/// A discretization represents some subset of the real numbers
/// (for example, `f64`).
///
/// Implementations must be [`Clone`].
///
/// NOTE(review): the `idx` argument of [`Discretization::convert`] and
/// [`Discretization::distance`] presumably selects which output expression
/// the value belongs to, so that each output can be discretized
/// differently — confirm against the callers.
pub trait Discretization: Clone {
    /// The precision in bits needed to exactly represent a value in this subset.
    fn target(&self) -> u32;
    /// Convert a bigfloat value to a value in this subset.
    ///
    /// The result is returned as a [`Float`] as well.
    fn convert(&self, idx: usize, v: &Float) -> Float;
    /// Determine how close two values in the subset are.
    ///
    /// A distance of `0` indicates that the two values are equal.
    /// A distance of `2` or greater indicates that the two values are far apart.
    /// A value of exactly `1` indicates that the two values are sequential,
    /// that is, that they share a rounding boundary.
    /// This last case triggers special behavior inside Rival
    /// to handle double-rounding issues.
    fn distance(&self, idx: usize, lo: &Float, hi: &Float) -> usize;
}
37
/// Interval evaluation machine with persistent state and discretization.
///
/// A machine is compiled from a list of real-number expressions via
/// [`MachineBuilder::build`]. The resulting value is opaque to callers and
/// can then be evaluated at specific input points using [`Machine::apply`].
///
/// Internally, a machine converts expressions into a simple register machine.
/// Compilation is fairly slow, so the ideal use case is to compile a function
/// once and then apply it to multiple points.
///
/// If more than one expression is provided, common subexpressions will be
/// identified and eliminated during compilation. This makes Rival ideal for
/// evaluating large families of related expressions, a feature that is
/// heavily used in [Herbie](https://herbie.uwplse.org). Note that each
/// expression can use a different discretization.
pub struct Machine<D: Discretization> {
    // Discretization supplied to the builder.
    pub(crate) disc: D,

    // Program structure.
    // Names of the free variables, in the order given to the builder.
    pub(crate) arguments: Vec<String>,
    // Lowered instructions; each records the register it writes to.
    pub(crate) instructions: Vec<Instruction>,
    // Register index holding the root of each compiled expression.
    pub(crate) outputs: Vec<usize>,

    // Initial state computed during compilation.
    // true = instruction depends on no argument and was evaluated at build time.
    pub(crate) initial_repeats: Vec<bool>,
    // Starting precision of each instruction.
    pub(crate) initial_precisions: Vec<u32>,
    // Precision at which an instruction was pre-evaluated at build time; 0 otherwise.
    pub(crate) best_known_precisions: Vec<u32>,
    // One hint per instruction; the builder fills this with `Hint::Execute`.
    pub(crate) default_hint: Vec<Hint>,

    // Runtime state.
    // One slot per argument followed by one slot per instruction.
    pub(crate) registers: Vec<Ival>,
    pub(crate) precisions: Vec<u32>,
    pub(crate) repeats: Vec<bool>, // true = skip execution (no change needed)
    pub(crate) output_distance: Vec<bool>, // true = output near discretization boundary

    // Reported by `Machine::iterations`.
    pub(crate) iteration: usize,
    pub(crate) bumps: usize, // Number of times bumps mode has been activated

    // Profiling.
    pub(crate) profiler: Profiler,
    pub(crate) profiling_enabled: bool,

    // Configuration parameters.
    pub(crate) max_precision: u32,
    pub(crate) min_precision: u32,
    // The builder always starts this as false.
    pub(crate) lower_bound_early_stopping: bool,
    pub(crate) slack_unit: i64,
    // The builder always starts this as false.
    pub(crate) bumps_activated: bool,
}
89
/// Hints guide the execution of individual instructions in the machine.
///
/// Each instruction carries one hint. A freshly built machine uses
/// [`Hint::Execute`] for every instruction.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Hint {
    /// Instruction executes normally.
    Execute,
    /// Instruction is not needed for the outputs being computed.
    /// Examples include dead code or an untaken branch.
    Skip,
    /// Skip execution and copy the value from the chosen input position.
    /// The payload is the position of that input.
    /// For example, `if (true) x else y` becomes an alias to `x`.
    Alias(u8),
    /// Skip execution because the result is already known exactly.
    /// The payload is the known boolean value.
    /// For example, `if (x > 0)` where `x = [5, 10]` is always true.
    KnownBool(bool),
}
105
/// Pairs a [`Hint`] with a convergence flag.
///
/// NOTE(review): presumably produced while analysing which path an
/// instruction takes — confirm against the code that consumes it.
#[derive(Clone, Debug)]
pub(crate) struct PathOutcome {
    /// Hint chosen for the instruction.
    pub hint: Hint,
    /// Convergence flag. The `alias` and `known_bool` constructors always
    /// set it to `true`; `execute` takes it from the caller.
    pub converged: bool,
}
111
/// Builder for constructing a [`Machine`] with custom parameters.
///
/// # Example
///
/// The snippet is illustrative: `my_discretization`, `exprs`, and `vars`
/// are placeholders, so it is marked `ignore` rather than compiled as a
/// doctest.
///
/// ```ignore
/// let machine = MachineBuilder::new(my_discretization)
///     .min_precision(53)
///     .max_precision(10_000)
///     .build(exprs, vars);
/// ```
pub struct MachineBuilder<D: Discretization> {
    // Discretization handed to the finished machine.
    disc: D,
    // Minimum working precision in bits.
    min_precision: u32,
    // Maximum working precision in bits.
    max_precision: u32,
    // Slack unit used when computing slack bits.
    slack_unit: i64,
    // Base tuning precision added to discretization targets.
    base_tuning_precision: u32,
    // Amplification tuning bits added during propagation.
    ampl_tuning_bits: u32,
    // Capacity of the profiling buffer, in records.
    profile_capacity: usize,
    // Whether per-instruction profiling starts enabled.
    profiling_enabled: bool,
}
132
133impl<D: Discretization> MachineBuilder<D> {
134    /// Create a builder with default precision parameters.
135    ///
136    /// Defaults:
137    /// - `min_precision`: 20 bits
138    /// - `max_precision`: 10,000 bits
139    /// - `slack_unit`: 512
140    /// - `profiling`: enabled, buffer capacity 1000
141    pub fn new(disc: D) -> Self {
142        Self {
143            disc,
144            min_precision: 20,
145            max_precision: 10_000,
146            slack_unit: 512,
147            base_tuning_precision: 5,
148            ampl_tuning_bits: 2,
149            profile_capacity: 1000,
150            profiling_enabled: true,
151        }
152    }
153
154    /// Set the minimum working precision in bits.
155    pub fn min_precision(mut self, v: u32) -> Self {
156        self.min_precision = v;
157        self
158    }
159
160    /// Set the maximum working precision in bits.
161    pub fn max_precision(mut self, v: u32) -> Self {
162        self.max_precision = v;
163        self
164    }
165
166    /// Set the slack unit used when computing slack bits.
167    pub fn slack_unit(mut self, v: i64) -> Self {
168        self.slack_unit = v;
169        self
170    }
171
172    /// Set the base tuning precision added to discretization targets.
173    pub fn base_tuning_precision(mut self, v: u32) -> Self {
174        self.base_tuning_precision = v;
175        self
176    }
177
178    /// Set the amplification tuning bits added during propagation.
179    pub fn ampl_tuning_bits(mut self, v: u32) -> Self {
180        self.ampl_tuning_bits = v;
181        self
182    }
183
184    /// Enable or disable per-instruction profiling (enabled by default).
185    pub fn enable_profiling(mut self, enabled: bool) -> Self {
186        self.profiling_enabled = enabled;
187        self
188    }
189
190    /// Set profiling buffer capacity (default 1000 records).
191    pub fn profile_capacity(mut self, cap: usize) -> Self {
192        self.profile_capacity = cap;
193        self
194    }
195
196    /// Compile expressions into a machine.
197    ///
198    /// `exprs` is a list of real-number expressions, using [`Expr`](super::ast::Expr).
199    /// `vars` is a list of the free variables of these expressions.
200    /// An empty `vars` list can be provided if the expressions
201    /// have no free variables.
202    ///
203    /// Returns a [`Machine`], an opaque type that can be passed to
204    /// [`Machine::apply`] to evaluate the compiled real expressions
205    /// on a specific point.
206    pub fn build(self, exprs: Vec<Expr>, vars: Vec<String>) -> Machine<D> {
207        // Optimize and lower expressions to instructions.
208        let optimized_exprs = exprs.into_iter().map(ops::optimize_expr).collect();
209        let (instructions_map, roots) = lower(optimized_exprs, &vars);
210        let var_count = vars.len();
211        let instruction_count = instructions_map.len();
212        let register_count = var_count + instruction_count;
213
214        let mut registers = vec![Ival::zero(self.max_precision); register_count];
215
216        let instructions: Vec<Instruction> = instructions_map
217            .into_iter()
218            .map(|(data, register)| Instruction {
219                out: register,
220                data,
221            })
222            .collect();
223
224        let mut best_known_precisions = vec![0u32; instruction_count];
225        let initial_precisions = make_initial_precisions(
226            &instructions,
227            var_count,
228            &roots,
229            &self.disc,
230            self.base_tuning_precision,
231            self.ampl_tuning_bits,
232        );
233
234        let initial_repeats = make_initial_repeats(
235            &instructions,
236            var_count,
237            &mut registers,
238            &initial_precisions,
239            &mut best_known_precisions,
240        );
241
242        let default_hint = vec![Hint::Execute; instruction_count];
243        let precisions = vec![0u32; instruction_count];
244        let repeats = vec![false; instruction_count];
245        let mut output_distance = vec![false; roots.len()];
246        output_distance.fill(false);
247
248        Machine {
249            disc: self.disc,
250            arguments: vars,
251            instructions,
252            outputs: roots,
253            initial_repeats,
254            initial_precisions,
255            best_known_precisions,
256            default_hint,
257            registers,
258            precisions,
259            repeats,
260            output_distance,
261            iteration: 0,
262            bumps: 0,
263            max_precision: self.max_precision,
264            min_precision: self.min_precision,
265            lower_bound_early_stopping: false,
266            slack_unit: self.slack_unit,
267            bumps_activated: false,
268            profiler: Profiler::with_capacity(self.profile_capacity),
269            profiling_enabled: self.profiling_enabled,
270        }
271    }
272}
273
274impl<D: Discretization> Machine<D> {
275    /// Return the instruction index that writes to the given register when applicable.
276    #[inline]
277    pub(crate) fn register_to_instruction(&self, register: usize) -> Option<usize> {
278        let var_count = self.arguments.len();
279        if register >= var_count {
280            Some(register - var_count)
281        } else {
282            None
283        }
284    }
285
286    /// Return the register index corresponding to an instruction index.
287    #[inline]
288    pub(crate) fn instruction_register(&self, index: usize) -> usize {
289        self.arguments.len() + index
290    }
291
292    /// Return the total number of instructions in the compiled machine.
293    #[inline]
294    pub fn instruction_count(&self) -> usize {
295        self.instructions.len()
296    }
297
298    /// Return the number of input arguments expected by this machine.
299    #[inline]
300    pub fn argument_count(&self) -> usize {
301        self.arguments.len()
302    }
303
304    /// Return the discretization's target precision in bits.
305    #[inline]
306    pub fn target_precision(&self) -> u32 {
307        self.disc.target()
308    }
309
310    /// Return the minimum working precision in bits.
311    #[inline]
312    pub fn min_precision(&self) -> u32 {
313        self.min_precision
314    }
315
316    /// Return the maximum working precision in bits.
317    #[inline]
318    pub fn max_precision(&self) -> u32 {
319        self.max_precision
320    }
321
322    /// Return the precision used for input arguments.
323    #[inline]
324    pub fn argument_precision(&self) -> u32 {
325        self.disc.target().max(self.min_precision)
326    }
327
328    /// Set the maximum working precision in bits.
329    #[inline]
330    pub fn set_max_precision(&mut self, bits: u32) {
331        self.max_precision = bits;
332    }
333
334    /// Return the number of iterations needed for the most recent call to [`Machine::apply`].
335    #[inline]
336    pub fn iterations(&self) -> usize {
337        self.iteration
338    }
339
340    /// Return the number of bumps detected during the most recent call to [`Machine::apply`].
341    #[inline]
342    pub fn bumps(&self) -> usize {
343        self.bumps
344    }
345
346    /// Enable or disable per-instruction profiling.
347    #[inline]
348    pub fn set_profiling_enabled(&mut self, enabled: bool) {
349        self.profiling_enabled = enabled;
350    }
351
352    /// Returns whether per-instruction profiling is enabled.
353    #[inline]
354    pub fn profiling_enabled(&self) -> bool {
355        self.profiling_enabled
356    }
357
358    /// Return a slice of recorded [`Execution`] records from the profiling buffer.
359    #[inline]
360    pub fn execution_records(&self) -> &[Execution] {
361        self.profiler.records()
362    }
363
364    /// Reset the profiling buffer, discarding all recorded executions.
365    #[inline]
366    pub fn clear_executions(&mut self) {
367        self.profiler.reset();
368    }
369
370    /// Return the operation name for each instruction in the machine.
371    pub fn instruction_names(&self) -> Vec<&'static str> {
372        self.instructions
373            .iter()
374            .map(|instr| instr.data.name_static())
375            .collect()
376    }
377
378    /// Reconfigure the machine to use the baseline strategy.
379    ///
380    /// The baseline strategy uses a single global precision for all
381    /// instructions, doubling it each iteration. This is simpler but
382    /// less efficient than the default adaptive precision tuning.
383    pub fn configure_baseline(&mut self) {
384        let var_count = self.arguments.len();
385        let start_prec = self.disc.target().saturating_add(10);
386
387        self.initial_precisions.fill(start_prec);
388        self.best_known_precisions.fill(0);
389
390        self.initial_repeats = make_initial_repeats(
391            &self.instructions,
392            var_count,
393            &mut self.registers,
394            &self.initial_precisions,
395            &mut self.best_known_precisions,
396        );
397    }
398
399    /// Return a snapshot of recorded [`Execution`] records and reset
400    /// the internal buffer pointer.
401    pub fn take_executions(&mut self) -> Vec<Execution> {
402        let slice = self.execution_records().to_vec();
403        self.clear_executions();
404        slice
405    }
406}
407
408impl PathOutcome {
409    /// Create an execute outcome with the given convergence status.
410    #[inline]
411    pub(crate) fn execute(converged: bool) -> PathOutcome {
412        PathOutcome {
413            hint: Hint::Execute,
414            converged,
415        }
416    }
417
418    /// Create an alias outcome for the provided input position.
419    #[inline]
420    pub(crate) fn alias(idx: u8) -> PathOutcome {
421        PathOutcome {
422            hint: Hint::Alias(idx),
423            converged: true,
424        }
425    }
426
427    /// Create a known boolean outcome pinned to the provided value.
428    #[inline]
429    pub(crate) fn known_bool(value: bool) -> PathOutcome {
430        PathOutcome {
431            hint: Hint::KnownBool(value),
432            converged: true,
433        }
434    }
435}
436
437/// Lower optimized expressions into instructions with common subexpression elimination.
438pub(crate) fn lower(
439    exprs: Vec<Expr>,
440    vars: &[String],
441) -> (IndexMap<InstructionData, usize>, Vec<usize>) {
442    let mut current_reg = vars.len();
443    let mut nodes: IndexMap<InstructionData, usize> = IndexMap::new();
444    let var_lookup: HashMap<&str, usize> = vars
445        .iter()
446        .enumerate()
447        .map(|(idx, name)| (name.as_str(), idx))
448        .collect();
449
450    let roots: Vec<usize> = exprs
451        .iter()
452        .map(|expr| ops::lower_expr(expr, &var_lookup, &mut nodes, &mut current_reg))
453        .collect();
454
455    (nodes, roots)
456}
457
458/// Determine initial precision targets for each instruction.
459fn make_initial_precisions<D: Discretization>(
460    instructions: &[Instruction],
461    var_count: usize,
462    roots: &[usize],
463    disc: &D,
464    base_tuning_precision: u32,
465    ampl_tuning_bits: u32,
466) -> Vec<u32> {
467    let mut precisions = vec![0u32; instructions.len()];
468
469    // Initialize output nodes to target + base precision.
470    for &root in roots.iter() {
471        if root >= var_count {
472            precisions[root - var_count] = disc.target() + base_tuning_precision;
473        }
474    }
475
476    // Propagate precisions backward through the computation graph.
477    for idx in (0..instructions.len()).rev() {
478        let current_prec = precisions[idx];
479        instructions[idx].for_each_input(|reg| {
480            if reg >= var_count {
481                let input_idx = reg - var_count;
482                if input_idx != idx {
483                    precisions[input_idx] =
484                        precisions[input_idx].max(current_prec + ampl_tuning_bits);
485                }
486            }
487        });
488    }
489
490    precisions
491}
492
493/// Evaluate and mark constant-only nodes that can skip future execution.
494fn make_initial_repeats(
495    instructions: &[Instruction],
496    var_count: usize,
497    registers: &mut [Ival],
498    initial_precisions: &[u32],
499    best_known_precisions: &mut [u32],
500) -> Vec<bool> {
501    let mut initial_repeats = vec![true; instructions.len()];
502
503    for (idx, (instr, &prec)) in instructions.iter().zip(initial_precisions).enumerate() {
504        let mut depends = false;
505        instr.data.for_each_input(|reg| {
506            let child = reg as isize - var_count as isize;
507            if child == idx as isize {
508                return;
509            }
510            if child < 0 || !initial_repeats[child as usize] {
511                depends = true;
512            }
513        });
514
515        if depends {
516            initial_repeats[idx] = false;
517        } else {
518            execute::evaluate_instruction(instr, registers, prec);
519            best_known_precisions[idx] = prec;
520        }
521    }
522
523    initial_repeats
524}