arithmetic_eval/arith/mod.rs
1//! `Arithmetic` trait and its implementations.
2//!
3//! # Traits
4//!
5//! An [`Arithmetic`] defines fallible arithmetic operations on primitive values
6//! of an [`ExecutableModule`], namely, addition, subtraction, multiplication, division,
7//! exponentiation (all binary ops), and negation (a unary op). Any module can be run
//! with any `Arithmetic` on its primitive values, although some modules are naturally tied
9//! to a particular arithmetic or a class of arithmetics (e.g., arithmetics on finite fields).
10//!
11//! [`OrdArithmetic`] extends [`Arithmetic`] with a partial comparison operation
12//! (i.e., an analogue to [`PartialOrd`]). This is motivated by the fact that comparisons
13//! may be switched off during parsing, and some `Arithmetic`s do not have well-defined comparisons.
14//!
//! [`ArithmeticExt`] helps convert an [`Arithmetic`] into an [`OrdArithmetic`].
16//!
17//! # Implementations
18//!
19//! This module defines the following kinds of arithmetics:
20//!
21//! - [`StdArithmetic`] takes all implementations from the corresponding [`ops`](core::ops) traits.
22//! This means that it's safe to use *provided* the ops are infallible. As a counter-example,
23//! using [`StdArithmetic`] with built-in integer types (such as `u64`) is usually not a good
24//! idea since the corresponding ops have failure modes (e.g., division by zero or integer
25//! overflow).
26//! - [`WrappingArithmetic`] is defined for integer types; it uses wrapping semantics for all ops.
27//! - [`CheckedArithmetic`] is defined for integer types; it uses checked semantics for all ops.
28//! - [`ModularArithmetic`] operates on integers modulo the specified number.
29//!
30//! All defined [`Arithmetic`]s strive to be as generic as possible.
31//!
32//! [`ExecutableModule`]: crate::ExecutableModule
33
34#![allow(renamed_and_removed_lints, clippy::unknown_clippy_lints)]
// ^ `map_err_ignore` is newer than the MSRV, and `clippy::unknown_clippy_lints` has been
// removed since Rust 1.51.
37
38use core::{cmp::Ordering, fmt};
39
40use crate::error::ArithmeticError;
41
42#[cfg(feature = "bigint")]
43mod bigint;
44mod generic;
45mod modular;
46
47pub use self::{
48 generic::{
49 Checked, CheckedArithmetic, CheckedArithmeticKind, NegateOnlyZero, StdArithmetic,
50 Unchecked, WrappingArithmetic,
51 },
52 modular::{DoubleWidth, ModularArithmetic},
53};
54
55/// Encapsulates arithmetic operations on a certain primitive type (or an enum of primitive types).
56///
57/// Unlike operations on built-in integer types, arithmetic operations may be fallible.
58/// Additionally, the arithmetic can have a state. This is used, for example, in
59/// [`ModularArithmetic`], which stores the modulus in the state.
60pub trait Arithmetic<T> {
61 /// Adds two values.
62 ///
63 /// # Errors
64 ///
65 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
66 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError>;
67
68 /// Subtracts two values.
69 ///
70 /// # Errors
71 ///
72 /// Returns an error if the operation is unsuccessful (e.g., on integer underflow).
73 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError>;
74
75 /// Multiplies two values.
76 ///
77 /// # Errors
78 ///
79 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
80 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError>;
81
82 /// Divides two values.
83 ///
84 /// # Errors
85 ///
86 /// Returns an error if the operation is unsuccessful (e.g., if `y` is zero or does
87 /// not have a multiplicative inverse in the case of modular arithmetic).
88 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError>;
89
90 /// Raises `x` to the power of `y`.
91 ///
92 /// # Errors
93 ///
94 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
95 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError>;
96
97 /// Negates a value.
98 ///
99 /// # Errors
100 ///
101 /// Returns an error if the operation is unsuccessful (e.g., on integer overflow).
102 fn neg(&self, x: T) -> Result<T, ArithmeticError>;
103
104 /// Checks if two values are equal. Note that equality can be a non-trivial operation;
105 /// e.g., different numbers may be equal as per modular arithmetic.
106 fn eq(&self, x: &T, y: &T) -> bool;
107}
108
109/// Extends an [`Arithmetic`] with a comparison operation on values.
110pub trait OrdArithmetic<T>: Arithmetic<T> {
111 /// Compares two values. Returns `None` if the numbers are not comparable, or the comparison
112 /// result otherwise.
113 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering>;
114}
115
116impl<T> fmt::Debug for dyn OrdArithmetic<T> + '_ {
117 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
118 formatter.debug_tuple("OrdArithmetic").finish()
119 }
120}
121
/// Wrapper that extends an [`Arithmetic`] to an [`OrdArithmetic`] by pairing it with
/// a comparison function.
///
/// # Examples
///
/// Instances can only be created through the [`ArithmeticExt`] trait; see its
/// documentation for usage examples.
pub struct FullArithmetic<T, A> {
    /// Arithmetic that all non-comparison operations are delegated to.
    base: A,
    /// Function used to compare values.
    comparison: fn(&T, &T) -> Option<Ordering>,
}

// `Clone` and `Copy` are written by hand rather than derived: a derive would also demand
// `T: Clone` / `T: Copy`, although `T` only occurs inside the (always `Copy`) fn pointer.
impl<T, A> Clone for FullArithmetic<T, A>
where
    A: Clone,
{
    fn clone(&self) -> Self {
        let Self { base, comparison } = self;
        Self {
            base: base.clone(),
            comparison: *comparison,
        }
    }
}

impl<T, A> Copy for FullArithmetic<T, A> where A: Copy {}

/// Shows only the `base` arithmetic; the comparison function pointer is left out.
impl<T, A> fmt::Debug for FullArithmetic<T, A>
where
    A: fmt::Debug,
{
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut debug = formatter.debug_struct("FullArithmetic");
        debug.field("base", &self.base);
        debug.finish()
    }
}
152
153impl<T, A> Arithmetic<T> for FullArithmetic<T, A>
154where
155 A: Arithmetic<T>,
156{
157 #[inline]
158 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
159 self.base.add(x, y)
160 }
161
162 #[inline]
163 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
164 self.base.sub(x, y)
165 }
166
167 #[inline]
168 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
169 self.base.mul(x, y)
170 }
171
172 #[inline]
173 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
174 self.base.div(x, y)
175 }
176
177 #[inline]
178 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
179 self.base.pow(x, y)
180 }
181
182 #[inline]
183 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
184 self.base.neg(x)
185 }
186
187 #[inline]
188 fn eq(&self, x: &T, y: &T) -> bool {
189 self.base.eq(x, y)
190 }
191}
192
193impl<T, A> OrdArithmetic<T> for FullArithmetic<T, A>
194where
195 A: Arithmetic<T>,
196{
197 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
198 (self.comparison)(x, y)
199 }
200}
201
202/// Extension trait for [`Arithmetic`] allowing to combine the arithmetic with comparisons.
203///
204/// # Examples
205///
206/// ```
207/// use arithmetic_eval::arith::{ArithmeticExt, ModularArithmetic};
208/// # use arithmetic_eval::{ExecutableModule, Value};
209/// # use arithmetic_parser::grammars::{NumGrammar, Untyped, Parse};
210///
211/// # fn main() -> anyhow::Result<()> {
212/// let base = ModularArithmetic::new(11);
213///
214/// // `ModularArithmetic` requires to define how numbers will be compared -
215/// // and the simplest solution is to not compare them at all.
216/// let program = Untyped::<NumGrammar<u32>>::parse_statements("1 < 3 || 1 >= 3")?;
217/// let module = ExecutableModule::builder("test", &program)?.build();
218/// assert_eq!(
219/// module.with_arithmetic(&base.without_comparisons()).run()?,
220/// Value::Bool(false)
221/// );
222///
223/// // We can compare numbers by their integer value. This can lead
224/// // to pretty confusing results, though.
225/// let bogus_arithmetic = base.with_natural_comparison();
226/// let program = Untyped::<NumGrammar<u32>>::parse_statements(r#"
227/// (x, y, z) = (1, 12, 5);
228/// x == y && x < z && y > z
229/// "#)?;
230/// let module = ExecutableModule::builder("test", &program)?.build();
231/// assert_eq!(
232/// module.with_arithmetic(&bogus_arithmetic).run()?,
233/// Value::Bool(true)
234/// );
235///
236/// // It's possible to fix the situation using a custom comparison function,
237/// // which will compare numbers by their residual class.
238/// let less_bogus_arithmetic = base.with_comparison(|&x: &u32, &y: &u32| {
239/// (x % 11).partial_cmp(&(y % 11))
240/// });
241/// assert_eq!(
242/// module.with_arithmetic(&less_bogus_arithmetic).run()?,
243/// Value::Bool(false)
244/// );
245/// # Ok(())
246/// # }
247/// ```
248pub trait ArithmeticExt<T>: Arithmetic<T> + Sized {
249 /// Combines this arithmetic with a comparison function that assumes any two values are
250 /// incomparable.
251 fn without_comparisons(self) -> FullArithmetic<T, Self> {
252 FullArithmetic {
253 base: self,
254 comparison: |_, _| None,
255 }
256 }
257
258 /// Combines this arithmetic with a comparison function specified by the [`PartialOrd`]
259 /// implementation for `T`.
260 fn with_natural_comparison(self) -> FullArithmetic<T, Self>
261 where
262 T: PartialOrd,
263 {
264 FullArithmetic {
265 base: self,
266 comparison: |x, y| x.partial_cmp(y),
267 }
268 }
269
270 /// Combines this arithmetic with the specified comparison function.
271 fn with_comparison(
272 self,
273 comparison: fn(&T, &T) -> Option<Ordering>,
274 ) -> FullArithmetic<T, Self> {
275 FullArithmetic {
276 base: self,
277 comparison,
278 }
279 }
280}
281
282impl<T, A> ArithmeticExt<T> for A where A: Arithmetic<T> {}