// hopper_core/math/mod.rs

//! Checked arithmetic operations.
//!
//! Every operation reports failure — overflow, underflow, or division by
//! zero — as `ProgramError::ArithmeticOverflow` instead of panicking, so no
//! panics on-chain. `u128` intermediates are used where needed to prevent
//! premature overflow on real token amounts.

7use hopper_runtime::error::ProgramError;
8
9// ── Basic checked ops ────────────────────────────────────────────────────────
10
11/// Checked addition.
12#[inline(always)]
13pub fn checked_add(a: u64, b: u64) -> Result<u64, ProgramError> {
14    a.checked_add(b).ok_or(ProgramError::ArithmeticOverflow)
15}
16
17/// Checked subtraction.
18#[inline(always)]
19pub fn checked_sub(a: u64, b: u64) -> Result<u64, ProgramError> {
20    a.checked_sub(b).ok_or(ProgramError::ArithmeticOverflow)
21}
22
23/// Checked multiplication.
24#[inline(always)]
25pub fn checked_mul(a: u64, b: u64) -> Result<u64, ProgramError> {
26    a.checked_mul(b).ok_or(ProgramError::ArithmeticOverflow)
27}
28
29/// Checked division (returns error on divide by zero).
30#[inline(always)]
31pub fn checked_div(a: u64, b: u64) -> Result<u64, ProgramError> {
32    a.checked_div(b).ok_or(ProgramError::ArithmeticOverflow)
33}
34
35/// Checked addition for i64.
36#[inline(always)]
37pub fn checked_add_i64(a: i64, b: i64) -> Result<i64, ProgramError> {
38    a.checked_add(b).ok_or(ProgramError::ArithmeticOverflow)
39}
40
41/// Checked subtraction for i64.
42#[inline(always)]
43pub fn checked_sub_i64(a: i64, b: i64) -> Result<i64, ProgramError> {
44    a.checked_sub(b).ok_or(ProgramError::ArithmeticOverflow)
45}
46
47// ── Ceiling division ─────────────────────────────────────────────────────────
48
49/// Compute `ceil(a / b)` without overflow (for u64).
50#[inline(always)]
51pub fn div_ceil(a: u64, b: u64) -> Result<u64, ProgramError> {
52    if b == 0 {
53        return Err(ProgramError::ArithmeticOverflow);
54    }
55    Ok(a.div_ceil(b))
56}
57
58/// Checked ceiling division: `ceil(a / b)`.
59///
60/// Rounds up instead of truncating. Use for fee calculations and minimum
61/// outputs where truncation would favor the user at the protocol's expense.
62#[inline(always)]
63pub fn checked_div_ceil(a: u64, b: u64) -> Result<u64, ProgramError> {
64    if b == 0 {
65        return Err(ProgramError::ArithmeticOverflow);
66    }
67    Ok(a.checked_add(b - 1)
68        .ok_or(ProgramError::ArithmeticOverflow)?
69        / b)
70}
71
72// ── u128-intermediate ops ────────────────────────────────────────────────────
73
74/// Compute `(a * b) / c` with u128 intermediate to prevent overflow.
75///
76/// **The core DeFi math primitive.** Without u128 intermediate, `a * b`
77/// overflows for any token amounts > ~4.2B (common with 9-decimal mints).
78/// Returns floor division.
79///
80/// ```rust,ignore
81/// // Constant-product swap: dy = (y * dx) / (x + dx)
82/// let output = checked_mul_div(reserve_y, input, reserve_x + input)?;
83/// ```
84#[inline(always)]
85pub fn checked_mul_div(a: u64, b: u64, c: u64) -> Result<u64, ProgramError> {
86    if c == 0 {
87        return Err(ProgramError::ArithmeticOverflow);
88    }
89    let result = (a as u128)
90        .checked_mul(b as u128)
91        .ok_or(ProgramError::ArithmeticOverflow)?
92        / (c as u128);
93    to_u64(result)
94}
95
96/// Compute `ceil((a * b) / c)` with u128 intermediate.
97///
98/// Same as [`checked_mul_div`] but rounds up. Use for fee calculations
99/// so the protocol never rounds down to zero fee.
100#[inline(always)]
101pub fn checked_mul_div_ceil(a: u64, b: u64, c: u64) -> Result<u64, ProgramError> {
102    if c == 0 {
103        return Err(ProgramError::ArithmeticOverflow);
104    }
105    let numerator = (a as u128)
106        .checked_mul(b as u128)
107        .ok_or(ProgramError::ArithmeticOverflow)?;
108    let c128 = c as u128;
109    let result = numerator
110        .checked_add(c128 - 1)
111        .ok_or(ProgramError::ArithmeticOverflow)?
112        / c128;
113    to_u64(result)
114}
115
116/// Safe narrowing cast from u128 to u64.
117#[inline(always)]
118pub fn to_u64(val: u128) -> Result<u64, ProgramError> {
119    if val > u64::MAX as u128 {
120        return Err(ProgramError::ArithmeticOverflow);
121    }
122    Ok(val as u64)
123}
124
125// ── Basis-point helpers ──────────────────────────────────────────────────────
126
127/// Scale a value in basis points (BPS).
128/// `value * bps / 10_000`, with overflow protection via u128 intermediate.
129#[inline(always)]
130pub fn scale_bps(value: u64, bps: u64) -> Result<u64, ProgramError> {
131    checked_mul_div(value, bps, 10_000)
132}
133
134/// Basis-point fee (floor): `amount * bps / 10_000`.
135///
136/// Nearly every DeFi program computes fees in basis points.
137#[inline(always)]
138pub fn bps_of(amount: u64, basis_points: u16) -> Result<u64, ProgramError> {
139    checked_mul_div(amount, basis_points as u64, 10_000)
140}
141
142/// Basis-point fee (ceiling): `ceil(amount * bps / 10_000)`.
143///
144/// Fees must never round to zero. Use this so the protocol always
145/// collects at least 1 token unit of fee when configured.
146#[inline(always)]
147pub fn bps_of_ceil(amount: u64, basis_points: u16) -> Result<u64, ProgramError> {
148    checked_mul_div_ceil(amount, basis_points as u64, 10_000)
149}
150
151/// Scale a value by a fraction `(numerator / denominator)`.
152#[inline(always)]
153pub fn scale_fraction(value: u64, numerator: u64, denominator: u64) -> Result<u64, ProgramError> {
154    checked_mul_div(value, numerator, denominator)
155}
156
157// ── Decimal scaling ──────────────────────────────────────────────────────────
158
159/// Scale a token amount between different decimal precisions (floor).
160///
161/// Converts `amount` denominated in `from_decimals` to the equivalent
162/// value in `to_decimals`. Uses u128 intermediate to prevent overflow.
163///
164/// ```rust,ignore
165/// let scaled = scale_amount(1_000_000, 6, 9)?; // USDC → SOL precision
166/// assert_eq!(scaled, 1_000_000_000);
167/// ```
168#[inline(always)]
169pub fn scale_amount(amount: u64, from_decimals: u8, to_decimals: u8) -> Result<u64, ProgramError> {
170    if from_decimals == to_decimals {
171        return Ok(amount);
172    }
173    if to_decimals > from_decimals {
174        let factor = 10u128
175            .checked_pow((to_decimals - from_decimals) as u32)
176            .ok_or(ProgramError::ArithmeticOverflow)?;
177        let result = (amount as u128)
178            .checked_mul(factor)
179            .ok_or(ProgramError::ArithmeticOverflow)?;
180        to_u64(result)
181    } else {
182        let factor = 10u64
183            .checked_pow((from_decimals - to_decimals) as u32)
184            .ok_or(ProgramError::ArithmeticOverflow)?;
185        checked_div(amount, factor)
186    }
187}
188
189/// Scale a token amount between decimal precisions, rounding up.
190///
191/// Same as [`scale_amount`] but uses ceiling division when scaling down.
192/// Use for protocol-side calculations where truncating would short-change
193/// the protocol (e.g., minimum collateral requirements).
194#[inline(always)]
195pub fn scale_amount_ceil(
196    amount: u64,
197    from_decimals: u8,
198    to_decimals: u8,
199) -> Result<u64, ProgramError> {
200    if from_decimals == to_decimals {
201        return Ok(amount);
202    }
203    if to_decimals > from_decimals {
204        let factor = 10u128
205            .checked_pow((to_decimals - from_decimals) as u32)
206            .ok_or(ProgramError::ArithmeticOverflow)?;
207        let result = (amount as u128)
208            .checked_mul(factor)
209            .ok_or(ProgramError::ArithmeticOverflow)?;
210        to_u64(result)
211    } else {
212        let factor = 10u64
213            .checked_pow((from_decimals - to_decimals) as u32)
214            .ok_or(ProgramError::ArithmeticOverflow)?;
215        checked_div_ceil(amount, factor)
216    }
217}
218
219// ── Exponentiation ───────────────────────────────────────────────────────────
220
221/// Checked exponentiation via repeated squaring.
222///
223/// Computes `base^exp` with overflow checking at each step. Useful for
224/// compound interest and exponential decay.
225#[inline(always)]
226pub fn checked_pow(base: u64, exp: u32) -> Result<u64, ProgramError> {
227    if exp == 0 {
228        return Ok(1);
229    }
230    let mut result: u64 = 1;
231    let mut b = base;
232    let mut e = exp;
233    while e > 0 {
234        if e & 1 == 1 {
235            result = result
236                .checked_mul(b)
237                .ok_or(ProgramError::ArithmeticOverflow)?;
238        }
239        e >>= 1;
240        if e > 0 {
241            b = b.checked_mul(b).ok_or(ProgramError::ArithmeticOverflow)?;
242        }
243    }
244    Ok(result)
245}