arithmetic_eval/arith/
mod.rs

1//! `Arithmetic` trait and its implementations.
2//!
3//! # Traits
4//!
//! An [`Arithmetic`] defines fallible arithmetic operations on primitive values
//! of an [`ExecutableModule`], namely addition, subtraction, multiplication, division,
//! exponentiation (all binary ops), and negation (a unary op), together with an equality check.
//! Any module can be run with any `Arithmetic` on its primitive values, although some modules
//! are naturally tied to a particular arithmetic or class of arithmetics (e.g., arithmetics on
//! finite fields).
10//!
//! [`OrdArithmetic`] extends [`Arithmetic`] with a partial comparison operation
//! (i.e., an analogue of [`PartialOrd`]). Comparisons live in a separate trait because
//! they may be switched off during parsing, and because some `Arithmetic`s do not have
//! well-defined comparisons.
14//!
//! [`ArithmeticExt`] helps convert an [`Arithmetic`] into an [`OrdArithmetic`].
16//!
17//! # Implementations
18//!
19//! This module defines the following kinds of arithmetics:
20//!
21//! - [`StdArithmetic`] takes all implementations from the corresponding [`ops`](core::ops) traits.
22//!   This means that it's safe to use *provided* the ops are infallible. As a counter-example,
23//!   using [`StdArithmetic`] with built-in integer types (such as `u64`) is usually not a good
24//!   idea since the corresponding ops have failure modes (e.g., division by zero or integer
25//!   overflow).
26//! - [`WrappingArithmetic`] is defined for integer types; it uses wrapping semantics for all ops.
27//! - [`CheckedArithmetic`] is defined for integer types; it uses checked semantics for all ops.
28//! - [`ModularArithmetic`] operates on integers modulo the specified number.
29//!
30//! All defined [`Arithmetic`]s strive to be as generic as possible.
31//!
32//! [`ExecutableModule`]: crate::ExecutableModule
33
34#![allow(renamed_and_removed_lints, clippy::unknown_clippy_lints)]
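// ^ `clippy::unknown_clippy_lints` is allowed because some Clippy lints (e.g., `map_err_ignore`)
// are newer than the MSRV; `renamed_and_removed_lints` is allowed because
// `clippy::unknown_clippy_lints` itself is removed as of Rust 1.51.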
37
38use core::{cmp::Ordering, fmt};
39
40use crate::error::ArithmeticError;
41
42#[cfg(feature = "bigint")]
43mod bigint;
44mod generic;
45mod modular;
46
47pub use self::{
48    generic::{
49        Checked, CheckedArithmetic, CheckedArithmeticKind, NegateOnlyZero, StdArithmetic,
50        Unchecked, WrappingArithmetic,
51    },
52    modular::{DoubleWidth, ModularArithmetic},
53};
54
55/// Encapsulates arithmetic operations on a certain primitive type (or an enum of primitive types).
56///
57/// Unlike operations on built-in integer types, arithmetic operations may be fallible.
58/// Additionally, the arithmetic can have a state. This is used, for example, in
59/// [`ModularArithmetic`], which stores the modulus in the state.
60pub trait Arithmetic<T> {
61    /// Adds two values.
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
66    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError>;
67
68    /// Subtracts two values.
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if the operation is unsuccessful (e.g., on integer underflow).
73    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError>;
74
75    /// Multiplies two values.
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
80    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError>;
81
82    /// Divides two values.
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the operation is unsuccessful (e.g., if `y` is zero or does
87    /// not have a multiplicative inverse in the case of modular arithmetic).
88    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError>;
89
90    /// Raises `x` to the power of `y`.
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
95    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError>;
96
97    /// Negates a value.
98    ///
99    /// # Errors
100    ///
101    /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
102    fn neg(&self, x: T) -> Result<T, ArithmeticError>;
103
104    /// Checks if two values are equal. Note that equality can be a non-trivial operation;
105    /// e.g., different numbers may be equal as per modular arithmetic.
106    fn eq(&self, x: &T, y: &T) -> bool;
107}
108
109/// Extends an [`Arithmetic`] with a comparison operation on values.
110pub trait OrdArithmetic<T>: Arithmetic<T> {
111    /// Compares two values. Returns `None` if the numbers are not comparable, or the comparison
112    /// result otherwise.
113    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering>;
114}
115
116impl<T> fmt::Debug for dyn OrdArithmetic<T> + '_ {
117    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
118        formatter.debug_tuple("OrdArithmetic").finish()
119    }
120}
121
/// An [`Arithmetic`] paired with a comparison function, which together form an
/// [`OrdArithmetic`] implementation.
///
/// # Examples
///
/// The fields are private, so values of this type are obtained through the
/// [`ArithmeticExt`] trait. Its documentation contains usage examples.
pub struct FullArithmetic<T, A> {
    /// Underlying arithmetic that performs the arithmetic operations.
    base: A,
    /// Function used to answer comparison queries.
    comparison: fn(&T, &T) -> Option<Ordering>,
}
132
133impl<T, A: Clone> Clone for FullArithmetic<T, A> {
134    fn clone(&self) -> Self {
135        Self {
136            base: self.base.clone(),
137            comparison: self.comparison,
138        }
139    }
140}
141
142impl<T, A: Copy> Copy for FullArithmetic<T, A> {}
143
144impl<T, A: fmt::Debug> fmt::Debug for FullArithmetic<T, A> {
145    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
146        formatter
147            .debug_struct("FullArithmetic")
148            .field("base", &self.base)
149            .finish()
150    }
151}
152
153impl<T, A> Arithmetic<T> for FullArithmetic<T, A>
154where
155    A: Arithmetic<T>,
156{
157    #[inline]
158    fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
159        self.base.add(x, y)
160    }
161
162    #[inline]
163    fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
164        self.base.sub(x, y)
165    }
166
167    #[inline]
168    fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
169        self.base.mul(x, y)
170    }
171
172    #[inline]
173    fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
174        self.base.div(x, y)
175    }
176
177    #[inline]
178    fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
179        self.base.pow(x, y)
180    }
181
182    #[inline]
183    fn neg(&self, x: T) -> Result<T, ArithmeticError> {
184        self.base.neg(x)
185    }
186
187    #[inline]
188    fn eq(&self, x: &T, y: &T) -> bool {
189        self.base.eq(x, y)
190    }
191}
192
193impl<T, A> OrdArithmetic<T> for FullArithmetic<T, A>
194where
195    A: Arithmetic<T>,
196{
197    fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
198        (self.comparison)(x, y)
199    }
200}
201
202/// Extension trait for [`Arithmetic`] allowing to combine the arithmetic with comparisons.
203///
204/// # Examples
205///
206/// ```
207/// use arithmetic_eval::arith::{ArithmeticExt, ModularArithmetic};
208/// # use arithmetic_eval::{ExecutableModule, Value};
209/// # use arithmetic_parser::grammars::{NumGrammar, Untyped, Parse};
210///
211/// # fn main() -> anyhow::Result<()> {
212/// let base = ModularArithmetic::new(11);
213///
214/// // `ModularArithmetic` requires to define how numbers will be compared -
215/// // and the simplest solution is to not compare them at all.
216/// let program = Untyped::<NumGrammar<u32>>::parse_statements("1 < 3 || 1 >= 3")?;
217/// let module = ExecutableModule::builder("test", &program)?.build();
218/// assert_eq!(
219///     module.with_arithmetic(&base.without_comparisons()).run()?,
220///     Value::Bool(false)
221/// );
222///
223/// // We can compare numbers by their integer value. This can lead
224/// // to pretty confusing results, though.
225/// let bogus_arithmetic = base.with_natural_comparison();
226/// let program = Untyped::<NumGrammar<u32>>::parse_statements(r#"
227///     (x, y, z) = (1, 12, 5);
228///     x == y && x < z && y > z
229/// "#)?;
230/// let module = ExecutableModule::builder("test", &program)?.build();
231/// assert_eq!(
232///     module.with_arithmetic(&bogus_arithmetic).run()?,
233///     Value::Bool(true)
234/// );
235///
236/// // It's possible to fix the situation using a custom comparison function,
237/// // which will compare numbers by their residual class.
238/// let less_bogus_arithmetic = base.with_comparison(|&x: &u32, &y: &u32| {
239///     (x % 11).partial_cmp(&(y % 11))
240/// });
241/// assert_eq!(
242///     module.with_arithmetic(&less_bogus_arithmetic).run()?,
243///     Value::Bool(false)
244/// );
245/// # Ok(())
246/// # }
247/// ```
248pub trait ArithmeticExt<T>: Arithmetic<T> + Sized {
249    /// Combines this arithmetic with a comparison function that assumes any two values are
250    /// incomparable.
251    fn without_comparisons(self) -> FullArithmetic<T, Self> {
252        FullArithmetic {
253            base: self,
254            comparison: |_, _| None,
255        }
256    }
257
258    /// Combines this arithmetic with a comparison function specified by the [`PartialOrd`]
259    /// implementation for `T`.
260    fn with_natural_comparison(self) -> FullArithmetic<T, Self>
261    where
262        T: PartialOrd,
263    {
264        FullArithmetic {
265            base: self,
266            comparison: |x, y| x.partial_cmp(y),
267        }
268    }
269
270    /// Combines this arithmetic with the specified comparison function.
271    fn with_comparison(
272        self,
273        comparison: fn(&T, &T) -> Option<Ordering>,
274    ) -> FullArithmetic<T, Self> {
275        FullArithmetic {
276            base: self,
277            comparison,
278        }
279    }
280}
281
282impl<T, A> ArithmeticExt<T> for A where A: Arithmetic<T> {}