Skip to main content

rival/interval/
value.rs

1use rug::{
2    Assign, Float,
3    float::{OrdFloat, Round},
4    ops::AssignRound,
5};
6
7use crate::mpfr::zero;
8use rug::ops::NegAssign;
9
10#[derive(Debug, Clone, PartialEq, Eq, Hash)]
11pub(crate) struct Endpoint {
12    pub(crate) val: OrdFloat,
13    pub(crate) immovable: bool,
14}
15
16/// A standard Rival interval containing two arbitrary-precision endpoints.
17///
18/// A standard interval includes both endpoints. Neither endpoint
19/// is allowed to be NaN. Intervals can be either *real* (with
20/// [`rug::Float`] endpoints) or *boolean* (with endpoints 0 or 1).
21///
22/// # Boolean intervals
23///
24/// In a boolean interval, `false` is considered less than `true`,
25/// yielding three boolean interval values:
26/// - True: `[1, 1]` — constructed with [`Ival::bool_interval(true, true)`](Ival::bool_interval)
27/// - False: `[0, 0]` — constructed with [`Ival::bool_interval(false, false)`](Ival::bool_interval)
28/// - Uncertain: `[0, 1]` — constructed with [`Ival::bool_interval(false, true)`](Ival::bool_interval)
29///
30/// # Error intervals
31///
32/// Sometimes an interval will contain invalid inputs to some function.
33/// For example, `sqrt` is undefined for negative inputs! In cases
34/// like this, Rival's output interval will only consider valid inputs.
35/// Error flags are "sticky": further computations on an interval
36/// will maintain already-set error flags.
37///
38/// # Interval Operations
39///
40/// Rival aims to ensure three properties of all helper functions:
41///
42/// - **Soundness** means output intervals contain any output on inputs drawn
43///   from the input intervals. IEEE-1788 refers to this as the output interval
44///   being *valid*.
45///
46/// - **Refinement** means, moreover, that narrower input intervals lead to
47///   narrower output intervals. Rival's movability flags make this a somewhat
48///   more complicated property than typical.
49///
50/// - **Weak completeness** means, moreover, that Rival returns the narrowest
51///   possible valid interval. IEEE-1788 refers to this as the output interval
52///   being *tight*.
53///
54/// Weak completeness (tightness) is the strongest possible property,
55/// while soundness (validity) is the weakest, with refinement somewhere
56/// in between.
57///
58/// The typical use case for Rival is to recompute a certain expression at
59/// ever higher precision, until the computed interval is narrow enough.
60/// However, interval arithmetic is not complete. For example, due to the
61/// limitations of the underlying MPFR library, it's impossible to compute
62/// `(exp(x) / exp(x))` for large enough values of `x`.
63///
64/// While it's impossible to detect this in all cases, Rival provides
65/// support for *movability flags* that can detect many such instances
66/// automatically. Movability flags are correctly propagated by all of
67/// Rival's supported operations, and are set by functions such as `exp`.
68#[derive(Debug, Clone, PartialEq, Eq, Hash)]
69pub struct Ival {
70    pub(crate) lo: Endpoint,
71    pub(crate) hi: Endpoint,
72    pub(crate) err: ErrorFlags,
73}
74
/// Flags recording whether invalid inputs were discarded during computation.
///
/// When an interval contains inputs on which some function is undefined
/// (e.g. negative inputs to `sqrt`), these flags record what happened:
///
/// - [`partial`](ErrorFlags::partial): at least one invalid input was
///   discarded. This flag alone does not say whether any valid inputs remain;
///   it is also set together with `total` (for example by
///   [`ErrorFlags::error`]).
/// - [`total`](ErrorFlags::total): every input was invalid.
///
/// Error flags are "sticky": further computations on an interval keep any
/// flag that is already set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ErrorFlags {
    pub(crate) partial: bool,
    pub(crate) total: bool,
}

92impl Endpoint {
93    pub(crate) fn new(val: OrdFloat, immovable: bool) -> Self {
94        Endpoint { val, immovable }
95    }
96
97    #[inline]
98    pub(crate) fn as_float(&self) -> &Float {
99        self.val.as_float()
100    }
101
102    #[inline]
103    pub(crate) fn as_float_mut(&mut self) -> &mut Float {
104        self.val.as_float_mut()
105    }
106
107    pub(crate) fn endpoint_min2_assign(&mut self, b: Endpoint) {
108        use std::cmp::Ordering;
109        match self.val.cmp(&b.val) {
110            Ordering::Less => (),
111            Ordering::Greater => *self = b,
112            Ordering::Equal => self.immovable |= b.immovable,
113        }
114    }
115
116    pub(crate) fn endpoint_max2_assign(&mut self, b: Endpoint) {
117        use std::cmp::Ordering;
118        match self.val.cmp(&b.val) {
119            Ordering::Greater => (),
120            Ordering::Less => *self = b,
121            Ordering::Equal => self.immovable |= b.immovable,
122        }
123    }
124}
125
126impl Ival {
127    pub(crate) fn new(lo: Endpoint, hi: Endpoint, err: ErrorFlags) -> Self {
128        assert!(lo.as_float().prec() == hi.as_float().prec());
129        Ival { lo, hi, err }
130    }
131
132    /// Returns the low endpoint of this interval.
133    #[inline]
134    pub fn lo(&self) -> &Float {
135        self.lo.as_float()
136    }
137
138    /// Returns the high endpoint of this interval.
139    #[inline]
140    pub fn hi(&self) -> &Float {
141        self.hi.as_float()
142    }
143
144    #[inline]
145    pub fn lo_mut(&mut self) -> &mut Float {
146        self.lo.as_float_mut()
147    }
148
149    #[inline]
150    pub fn hi_mut(&mut self) -> &mut Float {
151        self.hi.as_float_mut()
152    }
153
154    #[inline]
155    pub fn lo_immovable(&self) -> bool {
156        self.lo.immovable
157    }
158
159    #[inline]
160    pub fn hi_immovable(&self) -> bool {
161        self.hi.immovable
162    }
163
164    #[inline]
165    pub fn set_immovable(&mut self, lo: bool, hi: bool) {
166        self.lo.immovable = lo;
167        self.hi.immovable = hi;
168    }
169
170    #[inline]
171    pub fn error_flags(&self) -> ErrorFlags {
172        self.err
173    }
174
175    #[inline]
176    pub fn set_error_flags(&mut self, err: ErrorFlags) {
177        self.err = err;
178    }
179
180    #[inline]
181    pub fn prec(&self) -> u32 {
182        self.lo.as_float().prec()
183    }
184
185    #[inline]
186    pub fn set_prec(&mut self, prec: u32) {
187        self.lo.as_float_mut().set_prec(prec);
188        self.hi.as_float_mut().set_prec(prec);
189    }
190
191    pub(crate) fn max_prec(&self) -> u32 {
192        // Assumed that the lo and high precisions are always the same.
193        // This is ony enforced in Ival::new however.
194        self.lo.as_float().prec()
195    }
196
197    pub(crate) fn neg_inplace(&mut self) {
198        self.lo.as_float_mut().neg_assign();
199        self.hi.as_float_mut().neg_assign();
200        std::mem::swap(&mut self.lo, &mut self.hi);
201    }
202
203    /// Construct an interval from two endpoints.
204    ///
205    /// If either endpoint is NaN, or if `lo == hi` and both
206    /// are infinite, an illegal interval is returned (with error
207    /// flags set). The interval is considered movable.
208    pub fn from_lo_hi(lo: Float, hi: Float) -> Self {
209        let err = if lo.is_nan() || hi.is_nan() || (lo.eq(&hi) && lo.is_infinite()) {
210            ErrorFlags::error()
211        } else {
212            ErrorFlags::none()
213        };
214        Ival {
215            lo: Endpoint::new(OrdFloat::from(lo), false),
216            hi: Endpoint::new(OrdFloat::from(hi), false),
217            err,
218        }
219    }
220
221    /// Construct a boolean interval.
222    ///
223    /// A boolean interval has 2-bit precision endpoints:
224    /// `false` is represented as `0` and `true` as `1`.
225    /// Boolean intervals are always immovable.
226    #[inline]
227    pub fn bool_interval(lo_true: bool, hi_true: bool) -> Self {
228        // 2-bit precision is sufficient for 0/1 endpoints.
229        let to_float = |b: bool| Float::with_val(2, if b { 1 } else { 0 });
230        let (lo, hi) = (to_float(lo_true), to_float(hi_true));
231        let err = if lo.is_nan() || hi.is_nan() || (lo.eq(&hi) && lo.is_infinite()) {
232            ErrorFlags::error()
233        } else {
234            ErrorFlags::none()
235        };
236        Ival {
237            lo: Endpoint::new(OrdFloat::from(lo), true),
238            hi: Endpoint::new(OrdFloat::from(hi), true),
239            err,
240        }
241    }
242
243    pub fn f64_assign(&mut self, value: f64) {
244        self.lo.as_float_mut().assign_round(value, Round::Down);
245        self.hi.as_float_mut().assign_round(value, Round::Up);
246        self.err = ErrorFlags::none();
247    }
248
249    pub fn zero(prec: u32) -> Self {
250        let lo = Float::with_val(prec, 0);
251        let hi = Float::with_val(prec, 0);
252        Ival::new(
253            Endpoint::new(OrdFloat::from(lo), true),
254            Endpoint::new(OrdFloat::from(hi), true),
255            ErrorFlags::none(),
256        )
257    }
258
259    pub(crate) fn assign_from(&mut self, src: &Ival) {
260        // Ensure precision.
261        let src_prec = src.prec();
262        self.lo.as_float_mut().set_prec(src_prec);
263        self.hi.as_float_mut().set_prec(src_prec);
264        // Assign.
265        self.lo.as_float_mut().assign(src.lo.as_float());
266        self.lo.immovable = src.lo.immovable;
267        self.hi.as_float_mut().assign(src.hi.as_float());
268        self.hi.immovable = src.hi.immovable;
269        self.err = src.err;
270    }
271
272    /// Compute the union of this interval with `other`.
273    ///
274    /// Maintains error flags, and movability flags when possible.
275    /// If either interval is totally in error, the other is used
276    /// with its partial error flag set.
277    pub fn union_assign(&mut self, other: Ival) {
278        if self.err.total {
279            self.lo = other.lo;
280            self.hi = other.hi;
281            self.err = other.err;
282            self.err.partial = true;
283            return;
284        }
285
286        if other.err.total {
287            self.err.partial = true;
288            return;
289        }
290
291        self.lo.endpoint_min2_assign(other.lo);
292        self.hi.endpoint_max2_assign(other.hi);
293        self.err = self.err.union_disjoint(&other.err);
294    }
295
296    /// Return Some(false) if interval is exactly [0,0], Some(true) if [1,1], else None.
297    /// Returns None whenever there are error flags present.
298    pub(crate) fn known_bool(&self) -> Option<bool> {
299        if self.err.partial || self.err.total {
300            return None;
301        }
302        let lo = self.lo.as_float();
303        let hi = self.hi.as_float();
304        if lo.is_zero() && hi.is_zero() {
305            Some(false)
306        } else if *lo == 1 && *hi == 1 {
307            Some(true)
308        } else {
309            None
310        }
311    }
312
313    // The following helpers mirror previous clamp logic.
314    pub(crate) fn clamp(&mut self, lo: Float, hi: Float) {
315        let x_lo = self.lo.as_float();
316        let x_hi = self.hi.as_float();
317
318        self.err = ErrorFlags::new(
319            self.err.partial || x_lo < &lo || x_hi > &hi,
320            self.err.total || x_hi < &lo || x_lo > &hi,
321        );
322
323        if lo.is_zero() && x_hi.is_zero() {
324            self.lo.val = OrdFloat::from(zero(self.prec()));
325            self.hi.val = OrdFloat::from(zero(self.prec()));
326        } else {
327            if x_lo < &lo {
328                self.lo.val = OrdFloat::from(lo)
329            }
330
331            if x_hi > &hi {
332                self.hi.val = OrdFloat::from(hi);
333            }
334        }
335    }
336
337    pub(crate) fn clamp_strict(&mut self, lo: Float, hi: Float) {
338        let x_lo = self.lo.as_float();
339        let x_hi = self.hi.as_float();
340
341        self.err = ErrorFlags::new(
342            self.err.partial || x_lo <= &lo || x_hi >= &hi,
343            self.err.total || x_hi <= &lo || x_lo >= &hi,
344        );
345
346        if x_lo < &lo {
347            self.lo.val = OrdFloat::from(lo)
348        }
349
350        if x_hi > &hi {
351            self.hi.val = OrdFloat::from(hi);
352        }
353    }
354
355    /// Split an interval at a point, returning the two halves of that
356    /// interval on either side of the split point.
357    pub fn split_at(&self, val: &Float) -> (Ival, Ival) {
358        let lower = Ival::new(
359            self.lo.clone(),
360            Endpoint::new(OrdFloat::from(val.clone()), self.hi.immovable),
361            self.err,
362        );
363        let upper = Ival::new(
364            Endpoint::new(OrdFloat::from(val.clone()), self.lo.immovable),
365            self.hi.clone(),
366            self.err,
367        );
368        (lower, upper)
369    }
370}
371
372impl ErrorFlags {
373    pub fn new(partial: bool, total: bool) -> Self {
374        ErrorFlags { partial, total }
375    }
376
377    pub fn none() -> Self {
378        ErrorFlags::new(false, false)
379    }
380
381    pub fn error() -> Self {
382        ErrorFlags::new(true, true)
383    }
384
385    #[inline]
386    pub fn partial(&self) -> bool {
387        self.partial
388    }
389
390    #[inline]
391    pub fn total(&self) -> bool {
392        self.total
393    }
394
395    pub(crate) fn union(&self, other: &ErrorFlags) -> ErrorFlags {
396        ErrorFlags::new(self.partial || other.partial, self.total || other.total)
397    }
398
399    pub(crate) fn union_disjoint(&self, other: &ErrorFlags) -> ErrorFlags {
400        ErrorFlags::new(self.partial || other.partial, self.total && other.total)
401    }
402}
403
/// Sign classification of an interval, as produced by [`classify`].
///
/// The explicit discriminants follow the usual sign convention.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum IvalClass {
    Neg = -1,
    Mix = 0,
    Pos = 1,
}

411pub(crate) fn classify(ival: &Ival, strict: bool) -> IvalClass {
412    let lo = ival.lo.as_float();
413    let hi = ival.hi.as_float();
414    if strict {
415        if *lo > 0.0 {
416            IvalClass::Pos
417        } else if *hi < 0.0 {
418            IvalClass::Neg
419        } else {
420            IvalClass::Mix
421        }
422    } else if *lo >= 0.0 {
423        IvalClass::Pos
424    } else if *hi <= 0.0 {
425        IvalClass::Neg
426    } else {
427        IvalClass::Mix
428    }
429}