1use num_traits::{
4 checked_pow, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedSub, NumOps, One, Pow,
5 Signed, Unsigned, WrappingAdd, WrappingMul, WrappingNeg, WrappingSub, Zero,
6};
7
8use core::{cmp::Ordering, convert::TryFrom, marker::PhantomData, ops};
9
10use crate::{
11 arith::{Arithmetic, OrdArithmetic},
12 error::ArithmeticError,
13};
14
/// Arithmetic that forwards every operation to the native operators of the number type
/// (`+`, `-`, `*`, `/`, unary `-` and `Pow::pow`).
///
/// Operations never return an error: overflow, division by zero and similar conditions
/// behave exactly as the number type's own operator implementations define them.
#[derive(Clone, Copy, Debug, Default)]
pub struct StdArithmetic;
23
24impl<T> Arithmetic<T> for StdArithmetic
25where
26 T: Clone + NumOps + PartialEq + ops::Neg<Output = T> + Pow<T, Output = T>,
27{
28 #[inline]
29 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
30 Ok(x + y)
31 }
32
33 #[inline]
34 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
35 Ok(x - y)
36 }
37
38 #[inline]
39 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
40 Ok(x * y)
41 }
42
43 #[inline]
44 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
45 Ok(x / y)
46 }
47
48 #[inline]
49 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
50 Ok(x.pow(y))
51 }
52
53 #[inline]
54 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
55 Ok(-x)
56 }
57
58 #[inline]
59 fn eq(&self, x: &T, y: &T) -> bool {
60 *x == *y
61 }
62}
63
64impl<T> OrdArithmetic<T> for StdArithmetic
65where
66 Self: Arithmetic<T>,
67 T: PartialOrd,
68{
69 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
70 x.partial_cmp(y)
71 }
72}
73
// Compile-time checks (test builds only) for number types `StdArithmetic` is expected to
// support. The float checks are gated on the `std` feature, presumably because `num_traits`
// only provides float `Pow` impls with `std` (or `libm`) enabled — TODO confirm.
#[cfg(all(test, feature = "std"))]
static_assertions::assert_impl_all!(StdArithmetic: OrdArithmetic<f32>);
#[cfg(all(test, feature = "std"))]
static_assertions::assert_impl_all!(StdArithmetic: OrdArithmetic<f64>);

// Floating-point complex numbers (unordered, hence only `Arithmetic`).
#[cfg(all(test, feature = "complex"))]
static_assertions::assert_impl_all!(StdArithmetic: Arithmetic<num_complex::Complex32>);
#[cfg(all(test, feature = "complex"))]
static_assertions::assert_impl_all!(StdArithmetic: Arithmetic<num_complex::Complex64>);
82
/// Strategy used by `CheckedArithmetic` to negate values of type `T`.
///
/// Implementations decide which negations are representable; `None` signals that the
/// negation of `value` cannot be performed.
pub trait CheckedArithmeticKind<T> {
    /// Negates `value`, or returns `None` if the result cannot be represented.
    fn checked_neg(value: T) -> Option<T>;
}
88
89#[derive(Debug)]
99pub struct CheckedArithmetic<Kind = Checked>(PhantomData<Kind>);
100
101impl<Kind> Clone for CheckedArithmetic<Kind> {
102 fn clone(&self) -> Self {
103 Self(self.0)
104 }
105}
106
107impl<Kind> Copy for CheckedArithmetic<Kind> {}
108
109impl<Kind> Default for CheckedArithmetic<Kind> {
110 fn default() -> Self {
111 Self(PhantomData)
112 }
113}
114
115impl<Kind> CheckedArithmetic<Kind> {
116 pub const fn new() -> Self {
118 Self(PhantomData)
119 }
120}
121
122impl<T, Kind> Arithmetic<T> for CheckedArithmetic<Kind>
123where
124 T: Clone + PartialEq + Zero + One + CheckedAdd + CheckedSub + CheckedMul + CheckedDiv,
125 Kind: CheckedArithmeticKind<T>,
126 usize: TryFrom<T>,
127{
128 #[inline]
129 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
130 x.checked_add(&y).ok_or(ArithmeticError::IntegerOverflow)
131 }
132
133 #[inline]
134 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
135 x.checked_sub(&y).ok_or(ArithmeticError::IntegerOverflow)
136 }
137
138 #[inline]
139 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
140 x.checked_mul(&y).ok_or(ArithmeticError::IntegerOverflow)
141 }
142
143 #[inline]
144 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
145 x.checked_div(&y).ok_or(ArithmeticError::DivisionByZero)
146 }
147
148 #[inline]
149 #[allow(clippy::map_err_ignore)]
150 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
151 let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
152 checked_pow(x, exp).ok_or(ArithmeticError::IntegerOverflow)
153 }
154
155 #[inline]
156 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
157 Kind::checked_neg(x).ok_or(ArithmeticError::IntegerOverflow)
158 }
159
160 #[inline]
161 fn eq(&self, x: &T, y: &T) -> bool {
162 *x == *y
163 }
164}
165
166#[derive(Debug)]
169pub struct Checked(());
170
171impl<T: CheckedNeg> CheckedArithmeticKind<T> for Checked {
172 fn checked_neg(value: T) -> Option<T> {
173 value.checked_neg()
174 }
175}
176
177#[derive(Debug)]
179pub struct NegateOnlyZero(());
180
181impl<T: Unsigned + Zero> CheckedArithmeticKind<T> for NegateOnlyZero {
182 fn checked_neg(value: T) -> Option<T> {
183 if value.is_zero() {
184 Some(value)
185 } else {
186 None
187 }
188 }
189}
190
191#[derive(Debug)]
195pub struct Unchecked(());
196
197impl<T: Signed> CheckedArithmeticKind<T> for Unchecked {
198 fn checked_neg(value: T) -> Option<T> {
199 Some(-value)
200 }
201}
202
203impl<T, Kind> OrdArithmetic<T> for CheckedArithmetic<Kind>
204where
205 Self: Arithmetic<T>,
206 T: PartialOrd,
207{
208 #[inline]
209 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
210 x.partial_cmp(y)
211 }
212}
213
// Compile-time checks (test builds only) that the default `CheckedArithmetic` provides
// ordered arithmetic for all fixed-width primitive integer types.
#[cfg(test)]
static_assertions::assert_impl_all!(
    CheckedArithmetic: OrdArithmetic<u8>,
    OrdArithmetic<u16>,
    OrdArithmetic<u32>,
    OrdArithmetic<u64>,
    OrdArithmetic<u128>
);

#[cfg(test)]
static_assertions::assert_impl_all!(
    CheckedArithmetic: OrdArithmetic<i8>,
    OrdArithmetic<i16>,
    OrdArithmetic<i32>,
    OrdArithmetic<i64>,
    OrdArithmetic<i128>
);
227
/// Arithmetic that wraps around on integer overflow instead of reporting an error.
///
/// Division by zero and exponents that cannot be converted to `usize` are still reported
/// as errors.
#[derive(Default, Copy, Clone, Debug)]
pub struct WrappingArithmetic;
234
235impl<T> Arithmetic<T> for WrappingArithmetic
236where
237 T: Copy
238 + PartialEq
239 + Zero
240 + One
241 + WrappingAdd
242 + WrappingSub
243 + WrappingMul
244 + WrappingNeg
245 + ops::Div<T, Output = T>,
246 usize: TryFrom<T>,
247{
248 #[inline]
249 fn add(&self, x: T, y: T) -> Result<T, ArithmeticError> {
250 Ok(x.wrapping_add(&y))
251 }
252
253 #[inline]
254 fn sub(&self, x: T, y: T) -> Result<T, ArithmeticError> {
255 Ok(x.wrapping_sub(&y))
256 }
257
258 #[inline]
259 fn mul(&self, x: T, y: T) -> Result<T, ArithmeticError> {
260 Ok(x.wrapping_mul(&y))
261 }
262
263 #[inline]
264 fn div(&self, x: T, y: T) -> Result<T, ArithmeticError> {
265 if y.is_zero() {
266 Err(ArithmeticError::DivisionByZero)
267 } else if y.wrapping_neg().is_one() {
268 Ok(x.wrapping_neg())
271 } else {
272 Ok(x / y)
273 }
274 }
275
276 #[inline]
277 #[allow(clippy::map_err_ignore)]
278 fn pow(&self, x: T, y: T) -> Result<T, ArithmeticError> {
279 let exp = usize::try_from(y).map_err(|_| ArithmeticError::InvalidExponent)?;
280 Ok(wrapping_exp(x, exp))
281 }
282
283 #[inline]
284 fn neg(&self, x: T) -> Result<T, ArithmeticError> {
285 Ok(x.wrapping_neg())
286 }
287
288 #[inline]
289 fn eq(&self, x: &T, y: &T) -> bool {
290 *x == *y
291 }
292}
293
294impl<T> OrdArithmetic<T> for WrappingArithmetic
295where
296 Self: Arithmetic<T>,
297 T: PartialOrd,
298{
299 #[inline]
300 fn partial_cmp(&self, x: &T, y: &T) -> Option<Ordering> {
301 x.partial_cmp(y)
302 }
303}
304
305fn wrapping_exp<T: Copy + One + WrappingMul>(mut base: T, mut exp: usize) -> T {
308 if exp == 0 {
309 return T::one();
310 }
311
312 while exp & 1 == 0 {
313 base = base.wrapping_mul(&base);
314 exp >>= 1;
315 }
316 if exp == 1 {
317 return base;
318 }
319
320 let mut acc = base;
321 while exp > 1 {
322 exp >>= 1;
323 base = base.wrapping_mul(&base);
324 if exp & 1 == 1 {
325 acc = acc.wrapping_mul(&base);
326 }
327 }
328 acc
329}
330
// Compile-time checks (test builds only) that `WrappingArithmetic` provides ordered
// arithmetic for all fixed-width primitive integer types.
#[cfg(test)]
static_assertions::assert_impl_all!(
    WrappingArithmetic: OrdArithmetic<u8>,
    OrdArithmetic<u16>,
    OrdArithmetic<u32>,
    OrdArithmetic<u64>,
    OrdArithmetic<u128>
);

#[cfg(test)]
static_assertions::assert_impl_all!(
    WrappingArithmetic: OrdArithmetic<i8>,
    OrdArithmetic<i16>,
    OrdArithmetic<i32>,
    OrdArithmetic<i64>,
    OrdArithmetic<i128>
);