// fermat_core/arithmetic.rs

//! Arithmetic operations: add, sub, mul, div, mul_div, neg, abs.
//!
//! # Scale Alignment
//!
//! Before addition/subtraction, both operands are scaled to the same number of
//! decimal places (`max(a.scale, b.scale)`) via `align_scales`. This up-scaling
//! can itself overflow when the mantissa being scaled is large in magnitude
//! (close to the `i128` range, positive or negative); that is reported as
//! `ArithmeticError::Overflow`.
//!
//! # mul_div Safety
//!
//! `checked_mul_div` uses a 256-bit intermediate (`U256`) for `(a × b) / c` so
//! that the product does not silently wrap. See the crate-private `U256` type
//! in this module for the implementation.

15use crate::decimal::{Decimal, MAX_SCALE};
16use crate::error::ArithmeticError;
17
// ─── Powers of 10 table ──────────────────────────────────────────────────────

/// Powers of ten, `POW10[k] == 10^k` for `k` in `0..=28`.
///
/// The table is built by a `const` initializer, so it is fully evaluated at
/// compile time: lookups are a single indexed load with no runtime loop —
/// which keeps compute-unit cost flat on sBPF. `10^28` is the largest entry.
pub(crate) const POW10: [i128; 29] = {
    let mut table = [1i128; 29];
    let mut exp = 1;
    while exp < table.len() {
        // Each entry is the previous one shifted by one decimal place.
        table[exp] = table[exp - 1] * 10;
        exp += 1;
    }
    table
};
56/// Return `10^exp` as `i128` or `Err(ScaleExceeded)` if `exp > MAX_SCALE`.
57#[inline]
58pub(crate) fn pow10(exp: u8) -> Result<i128, ArithmeticError> {
59    POW10.get(exp as usize).copied().ok_or(ArithmeticError::ScaleExceeded)
60}

// ─── Scale alignment ─────────────────────────────────────────────────────────

64/// Align two operands to a common scale (`max(a.scale, b.scale)`).
65///
66/// Returns `(a_mantissa, b_mantissa, common_scale)`.
67/// Fails with `Overflow` if multiplying to scale up overflows `i128`.
68#[inline]
69pub(crate) fn align_scales(
70    a: Decimal,
71    b: Decimal,
72) -> Result<(i128, i128, u8), ArithmeticError> {
73    use core::cmp::Ordering;
74    match a.scale.cmp(&b.scale) {
75        Ordering::Equal => Ok((a.mantissa, b.mantissa, a.scale)),
76        Ordering::Less => {
77            let diff = b.scale - a.scale;
78            let factor = pow10(diff)?;
79            let scaled = a
80                .mantissa
81                .checked_mul(factor)
82                .ok_or(ArithmeticError::Overflow)?;
83            Ok((scaled, b.mantissa, b.scale))
84        }
85        Ordering::Greater => {
86            let diff = a.scale - b.scale;
87            let factor = pow10(diff)?;
88            let scaled = b
89                .mantissa
90                .checked_mul(factor)
91                .ok_or(ArithmeticError::Overflow)?;
92            Ok((a.mantissa, scaled, a.scale))
93        }
94    }
95}

// ─── Sign helper for mul_div ──────────────────────────────────────────────────

/// Sign of the exact rational result of `(a × b) / c`, determined from the
/// operands' signs before the magnitude is computed.
///
/// Note that a `Positive`/`Negative` sign can still accompany a truncated
/// quotient of zero (e.g. `1 × 1 / 3`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Sign {
    /// The exact result is greater than zero.
    Positive,
    /// The exact result is less than zero.
    Negative,
    /// One of the factors is zero, so the result is zero.
    Zero,
}

/// Compute the sign of `a * b / c` from the signs of its operands.
///
/// Returns [`Sign::Zero`] whenever `a` or `b` is zero. A zero `c` is treated
/// as non-negative and is *not* rejected here — callers must check for a zero
/// divisor themselves.
#[inline]
pub(crate) fn sign3(a: i128, b: i128, c: i128) -> Sign {
    if a == 0 || b == 0 {
        Sign::Zero
    } else if (a < 0) ^ (b < 0) ^ (c < 0) {
        // An odd number of negative operands gives a negative result.
        Sign::Negative
    } else {
        Sign::Positive
    }
}

// ─── 256-bit unsigned integer ─────────────────────────────────────────────────

/// 256-bit unsigned integer represented as two 128-bit limbs.
///
/// Used as an intermediate type (e.g. in [`Decimal::checked_mul_div`]) so that
/// a 128 × 128-bit product `a × b` cannot silently overflow `i128` / `u128`.
///
/// Layout: `value = hi * 2^128 + lo`
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct U256 {
    /// Least-significant 128 bits.
    pub lo: u128,
    /// Most-significant 128 bits.
    pub hi: u128,
}

impl U256 {
    /// The value zero.
    pub const ZERO: Self = Self { lo: 0, hi: 0 };

    /// Exact 128 × 128 → 256-bit multiplication using four 64-bit half-limbs.
    ///
    /// ```text
    /// a = a_hi * 2^64 + a_lo
    /// b = b_hi * 2^64 + b_lo
    /// a*b = hh * 2^128 + mid * 2^64 + ll
    ///     where mid = a_hi*b_lo + a_lo*b_hi
    /// ```
    pub fn mul(a: u128, b: u128) -> Self {
        const MASK64: u128 = u64::MAX as u128;
        let a_lo = a & MASK64;
        let a_hi = a >> 64;
        let b_lo = b & MASK64;
        let b_hi = b >> 64;

        // Each partial product is < 2^128, so none of these can overflow.
        let ll = a_lo * b_lo;
        let lh = a_lo * b_hi;
        let hl = a_hi * b_lo;
        let hh = a_hi * b_hi;

        // `mid` can exceed 128 bits; its carry is worth 2^128 * 2^64 = 2^192,
        // i.e. `1 << 64` in the high limb.
        let (mid, mid_carry) = lh.overflowing_add(hl);
        let (lo, lo_carry) = ll.overflowing_add(mid << 64);
        // The full product is < 2^256, so the high limb never truly overflows.
        let hi = hh
            .wrapping_add(mid >> 64)
            .wrapping_add(if mid_carry { 1u128 << 64 } else { 0 })
            .wrapping_add(lo_carry as u128);

        U256 { lo, hi }
    }

    /// 256-bit ÷ 128-bit division returning `(quotient, remainder)`.
    ///
    /// Returns `None` if `d == 0`, or if `self.hi >= d` (the quotient would not
    /// fit in a `u128`).
    ///
    /// ## Algorithm selection
    ///
    /// - **`hi == 0`**: plain 128-bit division.
    /// - **`d <= u64::MAX`**: two-step schoolbook division on 64-bit digits
    ///   (two 128-bit divisions; the common case for token amounts and scale
    ///   factors).
    /// - **`d > u64::MAX`**: restoring binary long division over the low 128
    ///   bits (128 iterations).
    ///
    /// In the last two cases the guard `hi < d` means `hi` is already a valid
    /// partial remainder for the upper half of the dividend, so only the low
    /// 128 bits still need to be processed.
    pub fn checked_div(self, d: u128) -> Option<(u128, u128)> {
        if d == 0 {
            return None;
        }
        // Fast path: numerator fits in 128 bits.
        if self.hi == 0 {
            return Some((self.lo / d, self.lo % d));
        }
        // Quotient would need more than 128 bits.
        if self.hi >= d {
            return None;
        }

        // ── d fits in 64 bits: divide the two 64-bit digits of `lo` ──────────
        if d <= u64::MAX as u128 {
            const MASK: u128 = u64::MAX as u128;
            // Every partial remainder is < d < 2^64, so `(r << 64) | digit`
            // never overflows, and every partial dividend is < d * 2^64, so
            // each partial quotient is < 2^64.
            let upper = (self.hi << 64) | (self.lo >> 64);
            let (q_upper, r_upper) = (upper / d, upper % d);
            let lower = (r_upper << 64) | (self.lo & MASK);
            let (q_lower, r_lower) = (lower / d, lower % d);
            return Some(((q_upper << 64) | q_lower, r_lower));
        }

        // ── General case: d > u64::MAX, binary long division ─────────────────
        //
        // Invariant: r < d at the end of every iteration (holds initially since
        // hi < d). When the bit shifted out of `r` is set, the true partial
        // remainder is `2^128 + r_new`, which is >= d and < 2d, so exactly one
        // subtraction is needed and `wrapping_sub` yields the correct value.
        let mut q: u128 = 0;
        let mut r: u128 = self.hi;
        for i in (0..128_u32).rev() {
            let bit = (self.lo >> i) & 1;
            let carry = r >> 127;
            let r_new = (r << 1) | bit;
            if carry == 1 || r_new >= d {
                r = r_new.wrapping_sub(d);
                q |= 1u128 << i;
            } else {
                r = r_new;
            }
        }

        Some((q, r))
    }
}

// ─── Decimal: add, sub, mul, div, neg, abs, mul_div ──────────────────────────

272impl Decimal {
273    /// Checked addition. Aligns scales then adds mantissas.
274    pub fn checked_add(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
275        let (a, b, scale) = align_scales(self, rhs)?;
276        let mantissa = a.checked_add(b).ok_or(ArithmeticError::Overflow)?;
277        Decimal::new(mantissa, scale)
278    }
279
280    /// Checked subtraction. Aligns scales then subtracts mantissas.
281    pub fn checked_sub(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
282        let (a, b, scale) = align_scales(self, rhs)?;
283        let mantissa = a.checked_sub(b).ok_or(ArithmeticError::Overflow)?;
284        Decimal::new(mantissa, scale)
285    }
286
287    /// Checked multiplication: `self * rhs`.
288    ///
289    /// Result scale = `self.scale + rhs.scale`.
290    /// Returns `Err(ScaleExceeded)` if that exceeds `MAX_SCALE`.
291    pub fn checked_mul(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
292        let new_scale = self
293            .scale
294            .checked_add(rhs.scale)
295            .filter(|&s| s <= MAX_SCALE)
296            .ok_or(ArithmeticError::ScaleExceeded)?;
297        let mantissa = self
298            .mantissa
299            .checked_mul(rhs.mantissa)
300            .ok_or(ArithmeticError::Overflow)?;
301        Decimal::new(mantissa, new_scale)
302    }
303
304    /// Checked division: `self / rhs`.
305    ///
306    /// Scales the numerator up by `MAX_SCALE - self.scale` places before
307    /// dividing to retain maximum precision.
308    pub fn checked_div(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
309        if rhs.mantissa == 0 {
310            return Err(ArithmeticError::DivisionByZero);
311        }
312        let extra = MAX_SCALE.saturating_sub(self.scale);
313        let factor = pow10(extra)?;
314        let scaled_num = self
315            .mantissa
316            .checked_mul(factor)
317            .ok_or(ArithmeticError::Overflow)?;
318        let mantissa = scaled_num
319            .checked_div(rhs.mantissa)
320            .ok_or(ArithmeticError::Overflow)?;
321        let raw_scale = (self.scale as i32) + (extra as i32) - (rhs.scale as i32);
322        if raw_scale < 0 {
323            return Err(ArithmeticError::Underflow);
324        }
325        Decimal::new(mantissa, (raw_scale as u8).min(MAX_SCALE))
326    }
327
328    /// Negation: returns `-self`.
329    ///
330    /// Fails with `Err(Overflow)` for `Decimal::MIN` (two's-complement has no
331    /// positive counterpart for `i128::MIN`).
332    pub fn checked_neg(self) -> Result<Decimal, ArithmeticError> {
333        let mantissa = self.mantissa.checked_neg().ok_or(ArithmeticError::Overflow)?;
334        Decimal::new(mantissa, self.scale)
335    }
336
337    /// Absolute value: returns `|self|`.
338    ///
339    /// Fails with `Err(Overflow)` for `Decimal::MIN`.
340    pub fn checked_abs(self) -> Result<Decimal, ArithmeticError> {
341        if self.mantissa >= 0 {
342            return Ok(self);
343        }
344        self.checked_neg()
345    }
346
347    /// Compound `(self × numerator) / denominator` with 256-bit intermediate.
348    ///
349    /// Prevents silent overflow that occurs when `self × numerator` exceeds
350    /// `i128::MAX` in a naive two-step `mul` then `div`.
351    pub fn checked_mul_div(
352        self,
353        numerator: Decimal,
354        denominator: Decimal,
355    ) -> Result<Decimal, ArithmeticError> {
356        if denominator.mantissa == 0 {
357            return Err(ArithmeticError::DivisionByZero);
358        }
359
360        let sign = sign3(self.mantissa, numerator.mantissa, denominator.mantissa);
361
362        let a = self.mantissa.unsigned_abs();
363        let b = numerator.mantissa.unsigned_abs();
364        let c = denominator.mantissa.unsigned_abs();
365
366        let product = U256::mul(a, b);
367        let (quotient_u128, _rem) =
368            product.checked_div(c).ok_or(ArithmeticError::Overflow)?;
369
370        let mantissa_abs =
371            i128::try_from(quotient_u128).map_err(|_| ArithmeticError::Overflow)?;
372
373        let signed_mantissa = match sign {
374            Sign::Zero => 0i128,
375            Sign::Positive => mantissa_abs,
376            Sign::Negative => {
377                mantissa_abs.checked_neg().ok_or(ArithmeticError::Overflow)?
378            }
379        };
380
381        let num_scale = self.scale as i32 + numerator.scale as i32;
382        let den_scale = denominator.scale as i32;
383        let result_scale = num_scale - den_scale;
384        if result_scale < 0 || result_scale > MAX_SCALE as i32 {
385            return Err(ArithmeticError::ScaleExceeded);
386        }
387
388        Decimal::new(signed_mantissa, result_scale as u8)
389    }
390}

// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Add a `u128` to a `U256`; used to check `q * d + r == n`.
    fn add_small(x: U256, r: u128) -> U256 {
        let (lo, carry) = x.lo.overflowing_add(r);
        U256 { lo, hi: x.hi + u128::from(carry) }
    }

    #[test]
    fn pow10_table_spot_checks() {
        assert_eq!(pow10(0).unwrap(), 1);
        assert_eq!(pow10(6).unwrap(), 1_000_000);
        assert_eq!(pow10(18).unwrap(), 1_000_000_000_000_000_000);
        assert_eq!(
            pow10(28).unwrap(),
            10_000_000_000_000_000_000_000_000_000i128
        );
        assert!(pow10(29).is_err());
    }

    #[test]
    fn u256_mul_small() {
        assert_eq!(U256::mul(3, 7), U256 { lo: 21, hi: 0 });
    }

    #[test]
    fn u256_mul_max_times_max() {
        let r = U256::mul(u128::MAX, u128::MAX);
        assert_eq!(r.lo, 1);
        assert_eq!(r.hi, u128::MAX - 1);
    }

    #[test]
    fn u256_mul_cross_limb() {
        assert_eq!(U256::mul(1u128 << 64, 1u128 << 64), U256 { lo: 0, hi: 1 });
    }

    #[test]
    fn u256_div_basic() {
        assert_eq!(U256 { lo: 21, hi: 0 }.checked_div(7), Some((3, 0)));
    }

    #[test]
    fn u256_div_by_zero() {
        assert_eq!(U256::ZERO.checked_div(0), None);
    }

    #[test]
    fn u256_div_overflow_check() {
        assert_eq!(U256 { lo: 0, hi: 100 }.checked_div(50), None);
        assert_eq!(U256 { lo: 0, hi: 100 }.checked_div(100), None);
    }

    /// `d <= u64::MAX` with a non-zero high limb.
    #[test]
    fn u256_div_small_divisor_with_high_limb() {
        // 2^128 ≡ 1 (mod 3), so 2^128 + 5 is divisible by 3.
        assert_eq!(
            U256 { lo: 5, hi: 1 }.checked_div(3),
            Some((u128::MAX / 3 + 2, 0))
        );
        assert_eq!(
            U256 { lo: 7, hi: 1 }.checked_div(2),
            Some(((1u128 << 127) + 3, 1))
        );
        let d = u64::MAX as u128;
        assert_eq!(U256::mul(5 * d, 7 * d).checked_div(d), Some((35 * d, 0)));
    }

    /// `d > u64::MAX`: the binary long-division path.
    #[test]
    fn u256_div_large_divisor() {
        let a = (1u128 << 100) + 12_345;
        let b = 987_654_321_987_654_321_987_654_321u128;
        assert_eq!(U256::mul(a, b).checked_div(a), Some((b, 0)));
        assert_eq!(
            U256::mul(u128::MAX, u128::MAX).checked_div(u128::MAX),
            Some((u128::MAX, 0))
        );
        assert_eq!(
            U256::mul(u128::MAX, 1u128 << 100).checked_div(1u128 << 101),
            Some((i128::MAX as u128, 1u128 << 100))
        );
    }

    /// Quotient/remainder must satisfy `q * d + r == n` and `r < d` on both
    /// the small-divisor and large-divisor paths.
    #[test]
    fn u256_div_reconstructs_dividend() {
        let cases: [(u128, u128, u128); 7] = [
            (u128::MAX, 12_345_678_901_234_567_890, 98_765_432_109_876_543_210_123),
            (1u128 << 127, 3, 5),
            (u128::MAX - 7, u128::MAX - 11, u128::MAX - 3),
            (10u128.pow(38), 10u128.pow(20), 10u128.pow(25)),
            ((1u128 << 90) + 17, (1u128 << 70) + 3, (1u128 << 65) + 1),
            (u128::MAX, 1u128 << 60, 0xFFFF_FFFF_FFFF_FFC5),
            (12_345_678_901_234_567_890_123_456_789, 98_765_432_109_876, 1_000_003),
        ];
        for &(x, y, d) in cases.iter() {
            let n = U256::mul(x, y);
            let (q, r) = n.checked_div(d).expect("quotient fits in u128");
            assert!(r < d);
            assert_eq!(add_small(U256::mul(q, d), r), n);
        }
    }

    #[test]
    fn sign3_rules() {
        assert!(matches!(sign3(0, 5, -1), Sign::Zero));
        assert!(matches!(sign3(7, 0, 1), Sign::Zero));
        assert!(matches!(sign3(2, -3, 1), Sign::Negative));
        assert!(matches!(sign3(-2, -3, -1), Sign::Negative));
        assert!(matches!(sign3(-2, 3, -1), Sign::Positive));
        assert!(matches!(sign3(2, 3, 1), Sign::Positive));
    }
}