rival/eval/run.rs
1//! Main evaluation loop with adaptive precision tuning.
2
3use itertools::{enumerate, izip};
4
5use crate::eval::{
6 execute,
7 machine::{Discretization, Hint, Machine},
8 profile::Execution,
9};
10use crate::interval::Ival;
11
12impl<D: Discretization> Machine<D> {
13 /// Evaluate the compiled real expressions on an input point
14 /// represented as a slice of intervals.
15 ///
16 /// `args` must be the same length as the `vars` passed to
17 /// [`MachineBuilder::build`](super::machine::MachineBuilder::build). The output is a vector of output
18 /// values of the same length as the `exprs` passed to
19 /// [`MachineBuilder::build`](super::machine::MachineBuilder::build).
20 ///
21 /// `hint` can be provided from a previous call to
22 /// [`Machine::analyze_with_hints`] to speed up evaluation.
23 /// Pass `None` for default behavior.
24 ///
25 /// `max_iterations` sets the maximum number of re-evaluation
26 /// iterations before giving up.
27 ///
28 /// # Errors
29 ///
30 /// Returns [`RivalError::InvalidInput`] if the point is an
31 /// invalid input to at least one of the compiled expressions.
32 /// Returns [`RivalError::Unsamplable`] if Rival is unable to
33 /// evaluate at least one expression.
34 ///
35 /// Note that `apply` will only return `Ok` if it can prove
36 /// that it has correctly-rounded the output. It will only
37 /// return `InvalidInput` if it can prove that at least one
38 /// output expression in the machine throws on the given input.
39 pub fn apply(
40 &mut self,
41 args: &[Ival],
42 hint: Option<&[Hint]>,
43 max_iterations: usize,
44 ) -> Result<Vec<Ival>, RivalError> {
45 self.load_arguments(args);
46 let hint_storage;
47 let hint_slice: &[Hint] = if let Some(h) = hint {
48 h
49 } else {
50 hint_storage = self.default_hint.clone();
51 &hint_storage
52 };
53
54 for iteration in 0..max_iterations {
55 if let Some(results) = self.run_iteration(iteration, hint_slice)? {
56 return Ok(results);
57 }
58 }
59
60 Err(RivalError::Unsamplable)
61 }
62
63 /// Evaluate the machine using the baseline strategy.
64 ///
65 /// The baseline strategy uses a single global precision for all
66 /// instructions, doubling it each iteration. This is simpler but
67 /// less efficient than [`Machine::apply`], which uses adaptive
68 /// per-instruction precision tuning.
69 ///
70 /// Call [`Machine::configure_baseline`] before using this method
71 /// to set up the machine for baseline evaluation.
72 pub fn apply_baseline(
73 &mut self,
74 args: &[Ival],
75 hint: Option<&[Hint]>,
76 ) -> Result<Vec<Ival>, RivalError> {
77 self.load_arguments(args);
78
79 let hint_storage;
80 let hint_slice: &[Hint] = if let Some(h) = hint {
81 h
82 } else {
83 hint_storage = self.default_hint.clone();
84 &hint_storage
85 };
86
87 let start_prec = self.disc.target().saturating_add(10);
88 let mut prec = start_prec;
89 let mut iter: usize = 0;
90
91 loop {
92 self.iteration = iter;
93 self.baseline_adjust(prec);
94 self.run_with_hint(hint_slice);
95
96 match self.collect_outputs()? {
97 Some(outputs) => return Ok(outputs),
98 None => {
99 let next = prec.saturating_mul(2);
100 if next > self.max_precision {
101 return Err(RivalError::Unsamplable);
102 }
103 prec = next;
104 iter = iter.saturating_add(1);
105 }
106 }
107 }
108 }
109
110 /// Analyze an input rectangle using the baseline strategy,
111 /// returning status, next hints, and a convergence flag.
112 ///
113 /// See [`Machine::analyze_with_hints`] for details on the
114 /// return values.
115 pub fn analyze_baseline_with_hints(
116 &mut self,
117 rect: &[Ival],
118 hint: Option<&[Hint]>,
119 ) -> (Ival, Vec<Hint>, bool) {
120 self.load_arguments(rect);
121
122 let tmp;
123 let hint_slice = if let Some(h) = hint {
124 h
125 } else {
126 tmp = self.default_hint.clone();
127 &tmp
128 };
129
130 self.iteration = 0;
131 self.baseline_adjust(self.disc.target().saturating_add(10));
132 self.run_with_hint(hint_slice);
133
134 let (good, _done, bad, stuck) = self.return_flags();
135 let (next_hint, converged) = self.make_hint(hint_slice);
136
137 let status = Ival::bool_interval(bad || stuck, (!good) || stuck);
138 (status, next_hint, converged)
139 }
140
141 /// Analyze a hyper-rectangle using the baseline strategy and
142 /// return only the boolean interval status.
143 ///
144 /// See [`Machine::analyze`] for details on the return value.
145 pub fn analyze_baseline(&mut self, rect: &[Ival]) -> Ival {
146 let (status, _hint, _conv) = self.analyze_baseline_with_hints(rect, None);
147 status
148 }
149
150 /// Run a single iteration with precision tuning and hint-guided evaluation.
151 pub(crate) fn run_iteration(
152 &mut self,
153 iteration: usize,
154 hints: &[Hint],
155 ) -> Result<Option<Vec<Ival>>, RivalError> {
156 assert_eq!(hints.len(), self.instructions.len(), "hint length mismatch");
157 self.iteration = iteration;
158 if self.adjust(hints) {
159 return Err(RivalError::Unsamplable);
160 }
161 self.run_with_hint(hints);
162 self.collect_outputs()
163 }
164
165 /// Analyze an input rectangle using adaptive precision tuning.
166 ///
167 /// Returns a `(status, hints, converged)` tuple:
168 ///
169 /// - `status` is a boolean interval indicating whether a call to
170 /// [`Machine::apply`] with inputs in the supplied `rect` is
171 /// guaranteed to raise an error. If false is returned, there is
172 /// no point calling `apply` with any point in the input range.
173 /// If uncertain, some points may raise errors while others may
174 /// not, though nothing is guaranteed. If true is returned,
175 /// `InvalidInput` will not be raised for any point in the range;
176 /// however, `Unsamplable` may still be raised.
177 ///
178 /// - `hints` is a vector of [`Hint`]s that can be passed to
179 /// subsequent calls to [`Machine::apply`] to skip unnecessary
180 /// computation.
181 ///
182 /// - `converged` indicates whether the analysis has converged.
183 pub fn analyze_with_hints(
184 &mut self,
185 rect: &[Ival],
186 hint: Option<&[Hint]>,
187 ) -> (Ival, Vec<Hint>, bool) {
188 self.load_arguments(rect);
189
190 // Use provided hint or default.
191 let tmp;
192 let hint_slice = if let Some(h) = hint {
193 h
194 } else {
195 tmp = self.default_hint.clone();
196 &tmp
197 };
198
199 // One analysis iteration at sampling iteration 0.
200 self.iteration = 0;
201 self.adjust(hint_slice);
202 self.run_with_hint(hint_slice);
203
204 let (good, _done, bad, stuck) = self.return_flags();
205 let (next_hint, converged) = self.make_hint(hint_slice);
206
207 let status = Ival::bool_interval(bad || stuck, (!good) || stuck);
208 (status, next_hint, converged)
209 }
210
211 /// Analyze a hyper-rectangle and return only the boolean interval status.
212 ///
213 /// Returns a boolean interval which indicates whether a call to
214 /// [`Machine::apply`], with inputs in the supplied `rect`, is
215 /// guaranteed to raise an error.
216 ///
217 /// In other words, if false is returned, there is no point calling
218 /// `apply` with any point in the input range. If uncertain, some
219 /// points in the range may raise errors, while others may not,
220 /// though nothing is guaranteed. If true is returned,
221 /// [`RivalError::InvalidInput`] will not be raised for any point
222 /// in the range. However, [`RivalError::Unsamplable`] may still
223 /// be raised.
224 ///
225 /// The advantage of `analyze` over `apply` is that it applies to
226 /// whole ranges of input points and is much faster.
227 pub fn analyze(&mut self, rect: &[Ival]) -> Ival {
228 let (status, _hint, _conv) = self.analyze_with_hints(rect, None);
229 status
230 }
231
232 /// Load argument intervals into the front of the register file.
233 pub(crate) fn load_arguments(&mut self, args: &[Ival]) {
234 assert_eq!(args.len(), self.arguments.len(), "Argument count mismatch");
235 for (i, arg) in args.iter().cloned().enumerate() {
236 self.registers[i] = arg;
237 }
238 self.bumps = 0;
239 self.bumps_activated = false;
240 self.iteration = 0;
241 self.precisions.fill(0);
242 self.repeats.fill(false);
243 self.output_distance.fill(false);
244 if self.profiling_enabled {
245 self.profiler.reset();
246 }
247 }
248
249 /// Execute instructions once using the supplied precision and hint plan.
250 fn run_with_hint(&mut self, hints: &[Hint]) {
251 // On the first iteration use the initial plan; subsequent iterations use tuned state.
252 let (precisions, repeats) = if self.iteration == 0 {
253 (&self.initial_precisions[..], &self.initial_repeats[..])
254 } else {
255 (&self.precisions[..], &self.repeats[..])
256 };
257
258 for (idx, (instruction, &repeat, &precision, hint)) in
259 enumerate(izip!(&self.instructions, repeats, precisions, hints))
260 {
261 if repeat {
262 continue;
263 }
264 let out_reg = self.instruction_register(idx);
265
266 // Hints can override execution.
267 match hint {
268 Hint::Skip => {}
269 Hint::Execute => {
270 if self.profiling_enabled {
271 let start = std::time::Instant::now();
272 execute::evaluate_instruction(instruction, &mut self.registers, precision);
273 let dt = start.elapsed().as_secs_f64() * 1000.0;
274 let exec = Execution {
275 name: instruction.data.name_static(),
276 number: idx as i32,
277 precision,
278 time_ms: dt,
279 iteration: self.iteration,
280 };
281 self.profiler.record(exec);
282 } else {
283 execute::evaluate_instruction(instruction, &mut self.registers, precision)
284 }
285 }
286 // Path reduction aliasing the output of an instruction to one of its inputs.
287 Hint::Alias(pos) => {
288 if let Some(src_reg) = instruction.data.input_at(*pos as usize)
289 && src_reg != out_reg
290 {
291 let (src, dst) = if src_reg < out_reg {
292 let (left, right) = self.registers.split_at_mut(out_reg);
293 (&left[src_reg], &mut right[0])
294 } else {
295 let (left, right) = self.registers.split_at_mut(src_reg);
296 (&right[0], &mut left[out_reg])
297 };
298 dst.assign_from(src);
299 }
300 }
301 // Use pre-computed boolean value.
302 Hint::KnownBool(value) => {
303 self.registers[out_reg] = Ival::bool_interval(*value, *value);
304 }
305 }
306 }
307 }
308
309 fn baseline_adjust(&mut self, new_prec: u32) {
310 let instruction_count = self.instructions.len();
311 let profiling = self.profiling_enabled;
312 let start_time = if profiling {
313 Some(std::time::Instant::now())
314 } else {
315 None
316 };
317
318 // Baseline uses a single global precision for all instructions.
319 self.precisions.fill(new_prec);
320
321 if self.iteration != 0 {
322 let var_count = self.arguments.len();
323
324 // Determine which instructions can affect outputs (must be executed).
325 let mut useful = vec![false; instruction_count];
326 for &root in &self.outputs {
327 if let Some(idx) = self.register_to_instruction(root) {
328 useful[idx] = true;
329 }
330 }
331
332 for idx in (0..instruction_count).rev() {
333 if !useful[idx] {
334 continue;
335 }
336 let out_reg = self.instruction_register(idx);
337 let reg = &self.registers[out_reg];
338 if reg.lo.immovable && reg.hi.immovable {
339 useful[idx] = false;
340 continue;
341 }
342 self.instructions[idx].for_each_input(|reg| {
343 if reg >= var_count {
344 useful[reg - var_count] = true;
345 }
346 });
347 }
348
349 // Set repeats and update constant precisions.
350 for idx in 0..instruction_count {
351 let is_constant = self.initial_repeats[idx];
352 let best_known = self.best_known_precisions[idx];
353
354 let mut inputs_stable = true;
355 if is_constant {
356 self.instructions[idx].for_each_input(|reg| {
357 if reg >= var_count && !self.repeats[reg - var_count] {
358 inputs_stable = false;
359 }
360 });
361 }
362
363 let no_need_to_reevaluate = is_constant && new_prec <= best_known && inputs_stable;
364 let result_is_exact_already = !useful[idx];
365 let repeat = result_is_exact_already || no_need_to_reevaluate;
366
367 if is_constant && !repeat {
368 self.best_known_precisions[idx] = new_prec;
369 }
370 self.repeats[idx] = repeat;
371 }
372 }
373
374 if profiling && let Some(t0) = start_time {
375 let dt_ms = t0.elapsed().as_secs_f64() * 1000.0;
376 self.profiler.record(Execution {
377 name: "adjust",
378 number: -1,
379 precision: (self.iteration as u32) * 1000,
380 time_ms: dt_ms,
381 iteration: self.iteration,
382 });
383 }
384 }
385
386 /// Gather outputs and translate evaluation state into convergence results.
387 fn collect_outputs(&mut self) -> Result<Option<Vec<Ival>>, RivalError> {
388 let (good, done, bad, stuck) = self.return_flags();
389 let mut outputs = Vec::with_capacity(self.outputs.len());
390
391 for &root in &self.outputs {
392 outputs.push(self.registers[root].clone());
393 }
394
395 if bad {
396 return Err(RivalError::InvalidInput);
397 }
398 if done && good {
399 return Ok(Some(outputs));
400 }
401 if stuck {
402 return Err(RivalError::Unsamplable);
403 }
404
405 Ok(None)
406 }
407
408 /// Compute (good, done, bad, stuck) flags and update output_distance like Racket's rival-machine-return.
409 fn return_flags(&mut self) -> (bool, bool, bool, bool) {
410 let mut good = true;
411 let mut done = true;
412 let mut bad = false;
413 let mut stuck = false;
414
415 for (idx, &root) in self.outputs.iter().enumerate() {
416 let value = &self.registers[root];
417 if value.err.total {
418 bad = true;
419 } else if value.err.partial {
420 good = false;
421 }
422 let lo = self.disc.convert(idx, value.lo.as_float());
423 let hi = self.disc.convert(idx, value.hi.as_float());
424 let dist = self.disc.distance(idx, &lo, &hi);
425 self.output_distance[idx] = dist == 1;
426 if dist != 0 {
427 done = false;
428 if value.lo.immovable && value.hi.immovable {
429 stuck = true;
430 }
431 }
432 }
433
434 (good, done, bad, stuck)
435 }
436}
437
438/// Errors that can occur during [`Machine::apply`].
439///
440/// Note that [`Machine::apply`] will only return a result if it can prove
441/// that it has correctly-rounded the output, and it will only return
442/// [`RivalError::InvalidInput`] if it can prove that at least one of the
443/// output expressions in the machine throws on the given input.
444#[derive(thiserror::Error, Debug)]
445pub enum RivalError {
446 /// The input point is invalid for at least one compiled expression.
447 ///
448 /// For example, taking the square root of a negative number, or
449 /// dividing by zero.
450 #[error("Invalid input for rival machine")]
451 InvalidInput,
452 /// Rival was unable to correctly round the output within the
453 /// configured precision and iteration limits.
454 #[error("Unsamplable input for rival machine")]
455 Unsamplable,
456}