Skip to main content

fermat_core/
arithmetic.rs

1//! Arithmetic operations: add, sub, mul, div, mul_div, neg, abs.
2//!
3//! # Scale Alignment
4//!
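//! Before addition/subtraction, both operands are scaled to the same number of
//! decimal places (`max(a.scale, b.scale)`) via `align_scales`. Rescaling
//! multiplies the lower-scale mantissa by `10^diff`, which can itself overflow
//! whenever `|mantissa| × 10^diff` leaves the `i128` range. With large scale
//! gaps this happens long before the mantissa is near `i128::MAX` or
//! `i128::MIN`. The failure is reported as `ArithmeticError::Overflow`.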
9//!
10//! # mul_div Safety
11//!
12//! `checked_mul_div` uses a 256-bit intermediate (`U256`) for `(a × b) / c` so
13//! that the product does not silently wrap. See [`U256`] for the implementation.
14
15use crate::decimal::{Decimal, MAX_SCALE};
16use crate::error::ArithmeticError;
17
18// ─── Powers of 10 table ──────────────────────────────────────────────────────
19
/// Powers of ten: `POW10[k] == 10^k` for every `k` in `0..=28`.
///
/// The table is filled in by a `const` initializer, so the work happens at
/// compile time. Reading an entry at runtime is a single O(1) array index with
/// no loop or `checked_pow` call — important for CU budget on sBPF.
pub(crate) const POW10: [i128; 29] = {
    let mut table = [1i128; 29];
    let mut k = 1;
    while k < 29 {
        // Each entry is ten times its predecessor; 10^28 still fits in i128.
        table[k] = table[k - 1] * 10;
        k += 1;
    }
    table
};
55
56/// Return `10^exp` as `i128` or `Err(ScaleExceeded)` if `exp > MAX_SCALE`.
57#[inline]
58pub(crate) fn pow10(exp: u8) -> Result<i128, ArithmeticError> {
59    POW10
60        .get(exp as usize)
61        .copied()
62        .ok_or(ArithmeticError::ScaleExceeded)
63}
64
65// ─── Scale alignment ─────────────────────────────────────────────────────────
66
67/// Align two operands to a common scale (`max(a.scale, b.scale)`).
68///
69/// Returns `(a_mantissa, b_mantissa, common_scale)`.
70/// Fails with `Overflow` if multiplying to scale up overflows `i128`.
71#[inline]
72pub(crate) fn align_scales(a: Decimal, b: Decimal) -> Result<(i128, i128, u8), ArithmeticError> {
73    use core::cmp::Ordering;
74    match a.scale.cmp(&b.scale) {
75        Ordering::Equal => Ok((a.mantissa, b.mantissa, a.scale)),
76        Ordering::Less => {
77            let diff = b.scale - a.scale;
78            let factor = pow10(diff)?;
79            let scaled = a
80                .mantissa
81                .checked_mul(factor)
82                .ok_or(ArithmeticError::Overflow)?;
83            Ok((scaled, b.mantissa, b.scale))
84        }
85        Ordering::Greater => {
86            let diff = a.scale - b.scale;
87            let factor = pow10(diff)?;
88            let scaled = b
89                .mantissa
90                .checked_mul(factor)
91                .ok_or(ArithmeticError::Overflow)?;
92            Ok((a.mantissa, scaled, a.scale))
93        }
94    }
95}
96
97// ─── Sign helper for mul_div ──────────────────────────────────────────────────
98
/// Sign of a product/quotient, as computed by [`sign3`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Sign {
    Positive,
    Negative,
    Zero,
}

/// Compute the sign of `a * b / c` from the signs of its `i128` operands.
///
/// Returns `Sign::Zero` when either factor `a` or `b` is zero. Otherwise the
/// result is negative exactly when an odd number of `a`, `b`, `c` are
/// negative.
///
/// Precondition: `c != 0`. This is not checked here; a zero `c` is treated as
/// non-negative. Callers must reject division by zero first.
#[inline]
pub(crate) fn sign3(a: i128, b: i128, c: i128) -> Sign {
    if a == 0 || b == 0 {
        return Sign::Zero;
    }
    // Each negative operand flips the sign; XOR accumulates the parity.
    // (Parentheses are required: `^` binds tighter than `<` in Rust.)
    if (a < 0) ^ (b < 0) ^ (c < 0) {
        Sign::Negative
    } else {
        Sign::Positive
    }
}
122
123// ─── 256-bit unsigned integer ─────────────────────────────────────────────────
124
/// 256-bit unsigned integer represented as two 128-bit limbs.
///
/// Used as a wide intermediate in decimal arithmetic (e.g.
/// [`Decimal::checked_mul_div`]) so that a product such as `a × b` cannot
/// silently overflow `i128` / `u128`.
///
/// Layout: `value = hi * 2^128 + lo`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct U256 {
    /// Least-significant 128 bits.
    pub lo: u128,
    /// Most-significant 128 bits.
    pub hi: u128,
}

impl U256 {
    #[cfg(test)]
    pub const ZERO: Self = Self { lo: 0, hi: 0 };

    /// Exact 128 × 128 → 256-bit multiplication using four 64-bit limbs.
    ///
    /// ```text
    /// a = a_hi * 2^64 + a_lo
    /// b = b_hi * 2^64 + b_lo
    /// a*b = hh * 2^128 + mid * 2^64 + ll
    ///     where mid = a_hi*b_lo + a_lo*b_hi
    /// ```
    ///
    /// Each 64×64 partial product fits in `u128`. Only `mid` can exceed 128
    /// bits; its carry is worth `2^192`, i.e. `2^64` in the `hi` limb.
    pub fn mul(a: u128, b: u128) -> Self {
        const MASK64: u128 = u64::MAX as u128;
        let a_lo = a & MASK64;
        let a_hi = a >> 64;
        let b_lo = b & MASK64;
        let b_hi = b >> 64;

        let ll = a_lo * b_lo;
        let lh = a_lo * b_hi;
        let hl = a_hi * b_lo;
        let hh = a_hi * b_hi;

        let (mid, mid_carry) = lh.overflowing_add(hl);
        let (lo, lo_carry) = ll.overflowing_add(mid << 64);
        // The true product is < 2^256, so these additions cannot actually wrap.
        let hi = hh
            .wrapping_add(mid >> 64)
            .wrapping_add(if mid_carry { 1u128 << 64 } else { 0 })
            .wrapping_add(lo_carry as u128);

        U256 { lo, hi }
    }

    /// 256-bit / 128-bit → `(quotient: u128, remainder: u128)`.
    ///
    /// Returns `None` if `d == 0`, or if the quotient does not fit in `u128`
    /// (equivalently `self.hi >= d`).
    ///
    /// ## Algorithm selection
    ///
    /// - **hi == 0**: plain 128-bit division.
    /// - **d ≤ u64::MAX**: two-digit long division in base `2^64`.
    /// - **d > u64::MAX**: binary long division over the 128 bits of `lo`.
    ///
    /// In the last two cases `hi < d`. Schoolbook long division over the bits
    /// of `hi` would emit only zero quotient bits and leave the remainder equal
    /// to `hi`. Both paths therefore start from `r = hi` and process only `lo`,
    /// which halves the work of the binary path compared to a 256-step loop.
    pub fn checked_div(self, d: u128) -> Option<(u128, u128)> {
        if d == 0 {
            return None;
        }
        // Fast path: numerator fits in 128 bits.
        if self.hi == 0 {
            return Some((self.lo / d, self.lo % d));
        }
        // Quotient overflow guard: value >= d * 2^128  ⇔  hi >= d.
        if self.hi >= d {
            return None;
        }

        // ── d fits in 64 bits: base-2^64 long division ──────────────────────
        // Invariant: r < d ≤ 2^64 - 1, so `(r << 64) | digit` never overflows
        // u128. Each partial numerator is < d * 2^64, so each digit of the
        // quotient is < 2^64 and the two digits recombine without overlap.
        if d <= u64::MAX as u128 {
            const MASK64: u128 = u64::MAX as u128;

            let n_hi = (self.hi << 64) | (self.lo >> 64);
            let q_hi = n_hi / d;
            let r_hi = n_hi % d;

            let n_lo = (r_hi << 64) | (self.lo & MASK64);
            let q_lo = n_lo / d;
            let r_lo = n_lo % d;

            return Some(((q_hi << 64) | q_lo, r_lo));
        }

        // ── d > 2^64: binary long division over the bits of `lo` ───────────
        //
        // Invariant: r < d at the end of every iteration.
        //
        // `r << 1` may lose its top bit (`carry`). In that case the true value
        // is `2^128 + shifted`, which is ≥ d (since d < 2^128) and < 2d (since
        // r < d). `wrapping_sub` then yields `2^128 + shifted - d` exactly.
        let mut q: u128 = 0;
        let mut r: u128 = self.hi;

        for i in (0..128_u32).rev() {
            let bit = (self.lo >> i) & 1;
            let carry = r >> 127;
            let shifted = (r << 1) | bit;

            if carry == 1 || shifted >= d {
                r = shifted.wrapping_sub(d);
                q |= 1u128 << i;
            } else {
                r = shifted;
            }
        }

        Some((q, r))
    }
}
274
// ─── Decimal: add, sub, mul, div, neg, abs, mul_div ──────────────────────────
276
277impl Decimal {
278    /// Checked addition. Aligns scales then adds mantissas.
279    pub fn checked_add(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
280        let (a, b, scale) = align_scales(self, rhs)?;
281        let mantissa = a.checked_add(b).ok_or(ArithmeticError::Overflow)?;
282        Decimal::new(mantissa, scale)
283    }
284
285    /// Checked subtraction. Aligns scales then subtracts mantissas.
286    pub fn checked_sub(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
287        let (a, b, scale) = align_scales(self, rhs)?;
288        let mantissa = a.checked_sub(b).ok_or(ArithmeticError::Overflow)?;
289        Decimal::new(mantissa, scale)
290    }
291
292    /// Checked multiplication: `self * rhs`.
293    ///
294    /// Result scale = `self.scale + rhs.scale`.
295    /// Returns `Err(ScaleExceeded)` if that exceeds `MAX_SCALE`.
296    pub fn checked_mul(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
297        let new_scale = self
298            .scale
299            .checked_add(rhs.scale)
300            .filter(|&s| s <= MAX_SCALE)
301            .ok_or(ArithmeticError::ScaleExceeded)?;
302        let mantissa = self
303            .mantissa
304            .checked_mul(rhs.mantissa)
305            .ok_or(ArithmeticError::Overflow)?;
306        Decimal::new(mantissa, new_scale)
307    }
308
309    /// Checked division: `self / rhs`.
310    ///
311    /// Scales the numerator up by `MAX_SCALE - self.scale` places before
312    /// dividing to retain maximum precision.
313    pub fn checked_div(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
314        if rhs.mantissa == 0 {
315            return Err(ArithmeticError::DivisionByZero);
316        }
317        let extra = MAX_SCALE.saturating_sub(self.scale);
318        let factor = pow10(extra)?;
319        let scaled_num = self
320            .mantissa
321            .checked_mul(factor)
322            .ok_or(ArithmeticError::Overflow)?;
323        let mantissa = scaled_num
324            .checked_div(rhs.mantissa)
325            .ok_or(ArithmeticError::Overflow)?;
326        let raw_scale = (self.scale as i32) + (extra as i32) - (rhs.scale as i32);
327        if raw_scale < 0 {
328            return Err(ArithmeticError::Underflow);
329        }
330        Decimal::new(mantissa, (raw_scale as u8).min(MAX_SCALE))
331    }
332
333    /// Negation: returns `-self`.
334    ///
335    /// Fails with `Err(Overflow)` for `Decimal::MIN` (two's-complement has no
336    /// positive counterpart for `i128::MIN`).
337    pub fn checked_neg(self) -> Result<Decimal, ArithmeticError> {
338        let mantissa = self
339            .mantissa
340            .checked_neg()
341            .ok_or(ArithmeticError::Overflow)?;
342        Decimal::new(mantissa, self.scale)
343    }
344
345    /// Absolute value: returns `|self|`.
346    ///
347    /// Fails with `Err(Overflow)` for `Decimal::MIN`.
348    pub fn checked_abs(self) -> Result<Decimal, ArithmeticError> {
349        if self.mantissa >= 0 {
350            return Ok(self);
351        }
352        self.checked_neg()
353    }
354
355    /// Compound `(self × numerator) / denominator` with 256-bit intermediate.
356    ///
357    /// Prevents silent overflow that occurs when `self × numerator` exceeds
358    /// `i128::MAX` in a naive two-step `mul` then `div`.
359    pub fn checked_mul_div(
360        self,
361        numerator: Decimal,
362        denominator: Decimal,
363    ) -> Result<Decimal, ArithmeticError> {
364        if denominator.mantissa == 0 {
365            return Err(ArithmeticError::DivisionByZero);
366        }
367
368        let sign = sign3(self.mantissa, numerator.mantissa, denominator.mantissa);
369
370        let a = self.mantissa.unsigned_abs();
371        let b = numerator.mantissa.unsigned_abs();
372        let c = denominator.mantissa.unsigned_abs();
373
374        let product = U256::mul(a, b);
375        let (quotient_u128, _rem) = product.checked_div(c).ok_or(ArithmeticError::Overflow)?;
376
377        let mantissa_abs = i128::try_from(quotient_u128).map_err(|_| ArithmeticError::Overflow)?;
378
379        let signed_mantissa = match sign {
380            Sign::Zero => 0i128,
381            Sign::Positive => mantissa_abs,
382            Sign::Negative => mantissa_abs
383                .checked_neg()
384                .ok_or(ArithmeticError::Overflow)?,
385        };
386
387        let num_scale = self.scale as i32 + numerator.scale as i32;
388        let den_scale = denominator.scale as i32;
389        let result_scale = num_scale - den_scale;
390        if result_scale < 0 || result_scale > MAX_SCALE as i32 {
391            return Err(ArithmeticError::ScaleExceeded);
392        }
393
394        Decimal::new(signed_mantissa, result_scale as u8)
395    }
396}
397
398// ─── Tests ───────────────────────────────────────────────────────────────────
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Rebuild `q * d + r` as a 256-bit value, to check division results
    /// against the original numerator.
    fn reconstruct(q: u128, d: u128, r: u128) -> U256 {
        let base = U256::mul(q, d);
        let (lo, carry) = base.lo.overflowing_add(r);
        U256 {
            lo,
            hi: base.hi + u128::from(carry),
        }
    }

    #[test]
    fn pow10_table_spot_checks() {
        assert_eq!(pow10(0).unwrap(), 1);
        assert_eq!(pow10(6).unwrap(), 1_000_000);
        assert_eq!(pow10(18).unwrap(), 1_000_000_000_000_000_000);
        assert_eq!(
            pow10(28).unwrap(),
            10_000_000_000_000_000_000_000_000_000i128
        );
        assert!(pow10(29).is_err());
    }

    #[test]
    fn u256_mul_small() {
        assert_eq!(U256::mul(3, 7), U256 { lo: 21, hi: 0 });
    }

    #[test]
    fn u256_mul_max_times_max() {
        let r = U256::mul(u128::MAX, u128::MAX);
        assert_eq!(r.lo, 1);
        assert_eq!(r.hi, u128::MAX - 1);
    }

    #[test]
    fn u256_mul_cross_limb_carry() {
        // 2^64 * 2^64 = 2^128: the result lands entirely in the high limb.
        assert_eq!(U256::mul(1u128 << 64, 1u128 << 64), U256 { lo: 0, hi: 1 });
    }

    #[test]
    fn u256_div_basic() {
        assert_eq!(U256 { lo: 21, hi: 0 }.checked_div(7), Some((3, 0)));
    }

    #[test]
    fn u256_div_by_zero() {
        assert_eq!(U256::ZERO.checked_div(0), None);
    }

    #[test]
    fn u256_div_overflow_check() {
        assert_eq!(U256 { lo: 0, hi: 100 }.checked_div(50), None);
        // hi == d is also a quotient overflow (value ≥ d * 2^128).
        assert_eq!(U256 { lo: 0, hi: 100 }.checked_div(100), None);
    }

    #[test]
    fn u256_div_small_divisor_with_remainder() {
        // 2^128 + 7 = 340_282_366_920_938_463_463_374_607_431_768_211_463
        assert_eq!(
            U256 { lo: 7, hi: 1 }.checked_div(10),
            Some((34_028_236_692_093_846_346_337_460_743_176_821_146, 3))
        );
        assert_eq!(U256::mul(u128::MAX, 3).checked_div(3), Some((u128::MAX, 0)));
    }

    #[test]
    fn u256_div_large_divisor_path() {
        // 2^128 + 5 = 2^28 * 2^100 + 5
        assert_eq!(
            U256 { lo: 5, hi: 1 }.checked_div(1u128 << 100),
            Some((1u128 << 28, 5))
        );
        // Divisor with its top bit set exercises the shift-carry branch.
        assert_eq!(
            U256::mul(u128::MAX, u128::MAX).checked_div(u128::MAX),
            Some((u128::MAX, 0))
        );
        assert_eq!(
            U256::mul(u128::MAX, u128::MAX - 1).checked_div(u128::MAX),
            Some((u128::MAX - 1, 0))
        );
    }

    #[test]
    fn u256_div_roundtrip() {
        // Every case satisfies hi < d, so the quotient fits in u128.
        let cases: [(u128, u128, u128); 8] = [
            (u128::MAX, 12_345, 1_000_000_007),
            (u128::MAX - 5, u128::MAX - 7, u128::MAX),
            (1u128 << 127, 3, (1u128 << 64) + 1),
            (
                0xDEAD_BEEF_CAFE_BABE_1234_5678_9ABC_DEF0,
                u64::MAX as u128,
                u64::MAX as u128,
            ),
            (123_456_789, 987_654_321, 10),
            (u128::MAX, 1u128 << 100, (1u128 << 100) + 3),
            (u128::MAX, u128::MAX, u128::MAX),
            (
                0x1234_5678_9ABC_DEF0_0FED_CBA9_8765_4321,
                0xFFFF_FFFF,
                (1u128 << 64) + 1,
            ),
        ];
        for &(a, b, d) in cases.iter() {
            let n = U256::mul(a, b);
            let (q, r) = n.checked_div(d).expect("quotient fits in u128");
            assert!(r < d);
            assert_eq!(reconstruct(q, d, r), n);
        }
    }

    #[test]
    fn sign3_parity() {
        assert!(matches!(sign3(0, -5, -1), Sign::Zero));
        assert!(matches!(sign3(7, 0, 1), Sign::Zero));
        assert!(matches!(sign3(2, 3, 4), Sign::Positive));
        assert!(matches!(sign3(-2, 3, 4), Sign::Negative));
        assert!(matches!(sign3(2, 3, -4), Sign::Negative));
        assert!(matches!(sign3(-2, -3, 4), Sign::Positive));
        assert!(matches!(sign3(-2, -3, -4), Sign::Negative));
    }
}