Skip to main content

rival/eval/
run.rs

1//! Main evaluation loop with adaptive precision tuning.
2
3use itertools::{enumerate, izip};
4
5use crate::eval::{
6    execute,
7    machine::{Discretization, Hint, Machine},
8    profile::Execution,
9};
10use crate::interval::Ival;
11
12impl<D: Discretization> Machine<D> {
13    /// Evaluate the compiled real expressions on an input point
14    /// represented as a slice of intervals.
15    ///
16    /// `args` must be the same length as the `vars` passed to
17    /// [`MachineBuilder::build`](super::machine::MachineBuilder::build). The output is a vector of output
18    /// values of the same length as the `exprs` passed to
19    /// [`MachineBuilder::build`](super::machine::MachineBuilder::build).
20    ///
21    /// `hint` can be provided from a previous call to
22    /// [`Machine::analyze_with_hints`] to speed up evaluation.
23    /// Pass `None` for default behavior.
24    ///
25    /// `max_iterations` sets the maximum number of re-evaluation
26    /// iterations before giving up.
27    ///
28    /// # Errors
29    ///
30    /// Returns [`RivalError::InvalidInput`] if the point is an
31    /// invalid input to at least one of the compiled expressions.
32    /// Returns [`RivalError::Unsamplable`] if Rival is unable to
33    /// evaluate at least one expression.
34    ///
35    /// Note that `apply` will only return `Ok` if it can prove
36    /// that it has correctly-rounded the output. It will only
37    /// return `InvalidInput` if it can prove that at least one
38    /// output expression in the machine throws on the given input.
39    pub fn apply(
40        &mut self,
41        args: &[Ival],
42        hint: Option<&[Hint]>,
43        max_iterations: usize,
44    ) -> Result<Vec<Ival>, RivalError> {
45        self.load_arguments(args);
46        let hint_storage;
47        let hint_slice: &[Hint] = if let Some(h) = hint {
48            h
49        } else {
50            hint_storage = self.default_hint.clone();
51            &hint_storage
52        };
53
54        for iteration in 0..max_iterations {
55            if let Some(results) = self.run_iteration(iteration, hint_slice)? {
56                return Ok(results);
57            }
58        }
59
60        Err(RivalError::Unsamplable)
61    }
62
63    /// Evaluate the machine using the baseline strategy.
64    ///
65    /// The baseline strategy uses a single global precision for all
66    /// instructions, doubling it each iteration. This is simpler but
67    /// less efficient than [`Machine::apply`], which uses adaptive
68    /// per-instruction precision tuning.
69    ///
70    /// Call [`Machine::configure_baseline`] before using this method
71    /// to set up the machine for baseline evaluation.
72    pub fn apply_baseline(
73        &mut self,
74        args: &[Ival],
75        hint: Option<&[Hint]>,
76    ) -> Result<Vec<Ival>, RivalError> {
77        self.load_arguments(args);
78
79        let hint_storage;
80        let hint_slice: &[Hint] = if let Some(h) = hint {
81            h
82        } else {
83            hint_storage = self.default_hint.clone();
84            &hint_storage
85        };
86
87        let start_prec = self.disc.target().saturating_add(10);
88        let mut prec = start_prec;
89        let mut iter: usize = 0;
90
91        loop {
92            self.iteration = iter;
93            self.baseline_adjust(prec);
94            self.run_with_hint(hint_slice);
95
96            match self.collect_outputs()? {
97                Some(outputs) => return Ok(outputs),
98                None => {
99                    let next = prec.saturating_mul(2);
100                    if next > self.max_precision {
101                        return Err(RivalError::Unsamplable);
102                    }
103                    prec = next;
104                    iter = iter.saturating_add(1);
105                }
106            }
107        }
108    }
109
110    /// Analyze an input rectangle using the baseline strategy,
111    /// returning status, next hints, and a convergence flag.
112    ///
113    /// See [`Machine::analyze_with_hints`] for details on the
114    /// return values.
115    pub fn analyze_baseline_with_hints(
116        &mut self,
117        rect: &[Ival],
118        hint: Option<&[Hint]>,
119    ) -> (Ival, Vec<Hint>, bool) {
120        self.load_arguments(rect);
121
122        let tmp;
123        let hint_slice = if let Some(h) = hint {
124            h
125        } else {
126            tmp = self.default_hint.clone();
127            &tmp
128        };
129
130        self.iteration = 0;
131        self.baseline_adjust(self.disc.target().saturating_add(10));
132        self.run_with_hint(hint_slice);
133
134        let (good, _done, bad, stuck) = self.return_flags();
135        let (next_hint, converged) = self.make_hint(hint_slice);
136
137        let status = Ival::bool_interval(bad || stuck, (!good) || stuck);
138        (status, next_hint, converged)
139    }
140
141    /// Analyze a hyper-rectangle using the baseline strategy and
142    /// return only the boolean interval status.
143    ///
144    /// See [`Machine::analyze`] for details on the return value.
145    pub fn analyze_baseline(&mut self, rect: &[Ival]) -> Ival {
146        let (status, _hint, _conv) = self.analyze_baseline_with_hints(rect, None);
147        status
148    }
149
150    /// Run a single iteration with precision tuning and hint-guided evaluation.
151    pub(crate) fn run_iteration(
152        &mut self,
153        iteration: usize,
154        hints: &[Hint],
155    ) -> Result<Option<Vec<Ival>>, RivalError> {
156        assert_eq!(hints.len(), self.instructions.len(), "hint length mismatch");
157        self.iteration = iteration;
158        if self.adjust(hints) {
159            return Err(RivalError::Unsamplable);
160        }
161        self.run_with_hint(hints);
162        self.collect_outputs()
163    }
164
165    /// Analyze an input rectangle using adaptive precision tuning.
166    ///
167    /// Returns a `(status, hints, converged)` tuple:
168    ///
169    /// - `status` is a boolean interval indicating whether a call to
170    ///   [`Machine::apply`] with inputs in the supplied `rect` is
171    ///   guaranteed to raise an error. If false is returned, there is
172    ///   no point calling `apply` with any point in the input range.
173    ///   If uncertain, some points may raise errors while others may
174    ///   not, though nothing is guaranteed. If true is returned,
175    ///   `InvalidInput` will not be raised for any point in the range;
176    ///   however, `Unsamplable` may still be raised.
177    ///
178    /// - `hints` is a vector of [`Hint`]s that can be passed to
179    ///   subsequent calls to [`Machine::apply`] to skip unnecessary
180    ///   computation.
181    ///
182    /// - `converged` indicates whether the analysis has converged.
183    pub fn analyze_with_hints(
184        &mut self,
185        rect: &[Ival],
186        hint: Option<&[Hint]>,
187    ) -> (Ival, Vec<Hint>, bool) {
188        self.load_arguments(rect);
189
190        // Use provided hint or default.
191        let tmp;
192        let hint_slice = if let Some(h) = hint {
193            h
194        } else {
195            tmp = self.default_hint.clone();
196            &tmp
197        };
198
199        // One analysis iteration at sampling iteration 0.
200        self.iteration = 0;
201        self.adjust(hint_slice);
202        self.run_with_hint(hint_slice);
203
204        let (good, _done, bad, stuck) = self.return_flags();
205        let (next_hint, converged) = self.make_hint(hint_slice);
206
207        let status = Ival::bool_interval(bad || stuck, (!good) || stuck);
208        (status, next_hint, converged)
209    }
210
211    /// Analyze a hyper-rectangle and return only the boolean interval status.
212    ///
213    /// Returns a boolean interval which indicates whether a call to
214    /// [`Machine::apply`], with inputs in the supplied `rect`, is
215    /// guaranteed to raise an error.
216    ///
217    /// In other words, if false is returned, there is no point calling
218    /// `apply` with any point in the input range. If uncertain, some
219    /// points in the range may raise errors, while others may not,
220    /// though nothing is guaranteed. If true is returned,
221    /// [`RivalError::InvalidInput`] will not be raised for any point
222    /// in the range. However, [`RivalError::Unsamplable`] may still
223    /// be raised.
224    ///
225    /// The advantage of `analyze` over `apply` is that it applies to
226    /// whole ranges of input points and is much faster.
227    pub fn analyze(&mut self, rect: &[Ival]) -> Ival {
228        let (status, _hint, _conv) = self.analyze_with_hints(rect, None);
229        status
230    }
231
232    /// Load argument intervals into the front of the register file.
233    pub(crate) fn load_arguments(&mut self, args: &[Ival]) {
234        assert_eq!(args.len(), self.arguments.len(), "Argument count mismatch");
235        for (i, arg) in args.iter().cloned().enumerate() {
236            self.registers[i] = arg;
237        }
238        self.bumps = 0;
239        self.bumps_activated = false;
240        self.iteration = 0;
241        self.precisions.fill(0);
242        self.repeats.fill(false);
243        self.output_distance.fill(false);
244        if self.profiling_enabled {
245            self.profiler.reset();
246        }
247    }
248
249    /// Execute instructions once using the supplied precision and hint plan.
250    fn run_with_hint(&mut self, hints: &[Hint]) {
251        // On the first iteration use the initial plan; subsequent iterations use tuned state.
252        let (precisions, repeats) = if self.iteration == 0 {
253            (&self.initial_precisions[..], &self.initial_repeats[..])
254        } else {
255            (&self.precisions[..], &self.repeats[..])
256        };
257
258        for (idx, (instruction, &repeat, &precision, hint)) in
259            enumerate(izip!(&self.instructions, repeats, precisions, hints))
260        {
261            if repeat {
262                continue;
263            }
264            let out_reg = self.instruction_register(idx);
265
266            // Hints can override execution.
267            match hint {
268                Hint::Skip => {}
269                Hint::Execute => {
270                    if self.profiling_enabled {
271                        let start = std::time::Instant::now();
272                        execute::evaluate_instruction(instruction, &mut self.registers, precision);
273                        let dt = start.elapsed().as_secs_f64() * 1000.0;
274                        let exec = Execution {
275                            name: instruction.data.name_static(),
276                            number: idx as i32,
277                            precision,
278                            time_ms: dt,
279                            iteration: self.iteration,
280                        };
281                        self.profiler.record(exec);
282                    } else {
283                        execute::evaluate_instruction(instruction, &mut self.registers, precision)
284                    }
285                }
286                // Path reduction aliasing the output of an instruction to one of its inputs.
287                Hint::Alias(pos) => {
288                    if let Some(src_reg) = instruction.data.input_at(*pos as usize)
289                        && src_reg != out_reg
290                    {
291                        let (src, dst) = if src_reg < out_reg {
292                            let (left, right) = self.registers.split_at_mut(out_reg);
293                            (&left[src_reg], &mut right[0])
294                        } else {
295                            let (left, right) = self.registers.split_at_mut(src_reg);
296                            (&right[0], &mut left[out_reg])
297                        };
298                        dst.assign_from(src);
299                    }
300                }
301                // Use pre-computed boolean value.
302                Hint::KnownBool(value) => {
303                    self.registers[out_reg] = Ival::bool_interval(*value, *value);
304                }
305            }
306        }
307    }
308
309    fn baseline_adjust(&mut self, new_prec: u32) {
310        let instruction_count = self.instructions.len();
311        let profiling = self.profiling_enabled;
312        let start_time = if profiling {
313            Some(std::time::Instant::now())
314        } else {
315            None
316        };
317
318        // Baseline uses a single global precision for all instructions.
319        self.precisions.fill(new_prec);
320
321        if self.iteration != 0 {
322            let var_count = self.arguments.len();
323
324            // Determine which instructions can affect outputs (must be executed).
325            let mut useful = vec![false; instruction_count];
326            for &root in &self.outputs {
327                if let Some(idx) = self.register_to_instruction(root) {
328                    useful[idx] = true;
329                }
330            }
331
332            for idx in (0..instruction_count).rev() {
333                if !useful[idx] {
334                    continue;
335                }
336                let out_reg = self.instruction_register(idx);
337                let reg = &self.registers[out_reg];
338                if reg.lo.immovable && reg.hi.immovable {
339                    useful[idx] = false;
340                    continue;
341                }
342                self.instructions[idx].for_each_input(|reg| {
343                    if reg >= var_count {
344                        useful[reg - var_count] = true;
345                    }
346                });
347            }
348
349            // Set repeats and update constant precisions.
350            for idx in 0..instruction_count {
351                let is_constant = self.initial_repeats[idx];
352                let best_known = self.best_known_precisions[idx];
353
354                let mut inputs_stable = true;
355                if is_constant {
356                    self.instructions[idx].for_each_input(|reg| {
357                        if reg >= var_count && !self.repeats[reg - var_count] {
358                            inputs_stable = false;
359                        }
360                    });
361                }
362
363                let no_need_to_reevaluate = is_constant && new_prec <= best_known && inputs_stable;
364                let result_is_exact_already = !useful[idx];
365                let repeat = result_is_exact_already || no_need_to_reevaluate;
366
367                if is_constant && !repeat {
368                    self.best_known_precisions[idx] = new_prec;
369                }
370                self.repeats[idx] = repeat;
371            }
372        }
373
374        if profiling && let Some(t0) = start_time {
375            let dt_ms = t0.elapsed().as_secs_f64() * 1000.0;
376            self.profiler.record(Execution {
377                name: "adjust",
378                number: -1,
379                precision: (self.iteration as u32) * 1000,
380                time_ms: dt_ms,
381                iteration: self.iteration,
382            });
383        }
384    }
385
386    /// Gather outputs and translate evaluation state into convergence results.
387    fn collect_outputs(&mut self) -> Result<Option<Vec<Ival>>, RivalError> {
388        let (good, done, bad, stuck) = self.return_flags();
389        let mut outputs = Vec::with_capacity(self.outputs.len());
390
391        for &root in &self.outputs {
392            outputs.push(self.registers[root].clone());
393        }
394
395        if bad {
396            return Err(RivalError::InvalidInput);
397        }
398        if done && good {
399            return Ok(Some(outputs));
400        }
401        if stuck {
402            return Err(RivalError::Unsamplable);
403        }
404
405        Ok(None)
406    }
407
408    /// Compute (good, done, bad, stuck) flags and update output_distance like Racket's rival-machine-return.
409    fn return_flags(&mut self) -> (bool, bool, bool, bool) {
410        let mut good = true;
411        let mut done = true;
412        let mut bad = false;
413        let mut stuck = false;
414
415        for (idx, &root) in self.outputs.iter().enumerate() {
416            let value = &self.registers[root];
417            if value.err.total {
418                bad = true;
419            } else if value.err.partial {
420                good = false;
421            }
422            let lo = self.disc.convert(idx, value.lo.as_float());
423            let hi = self.disc.convert(idx, value.hi.as_float());
424            let dist = self.disc.distance(idx, &lo, &hi);
425            self.output_distance[idx] = dist == 1;
426            if dist != 0 {
427                done = false;
428                if value.lo.immovable && value.hi.immovable {
429                    stuck = true;
430                }
431            }
432        }
433
434        (good, done, bad, stuck)
435    }
436}
437
438/// Errors that can occur during [`Machine::apply`].
439///
440/// Note that [`Machine::apply`] will only return a result if it can prove
441/// that it has correctly-rounded the output, and it will only return
442/// [`RivalError::InvalidInput`] if it can prove that at least one of the
443/// output expressions in the machine throws on the given input.
444#[derive(thiserror::Error, Debug)]
445pub enum RivalError {
446    /// The input point is invalid for at least one compiled expression.
447    ///
448    /// For example, taking the square root of a negative number, or
449    /// dividing by zero.
450    #[error("Invalid input for rival machine")]
451    InvalidInput,
452    /// Rival was unable to correctly round the output within the
453    /// configured precision and iteration limits.
454    #[error("Unsamplable input for rival machine")]
455    Unsamplable,
456}