jiminy/math.rs
1use pinocchio::error::ProgramError;
2
3/// Checked u64 addition: returns `ArithmeticOverflow` on overflow.
4#[inline(always)]
5pub fn checked_add(a: u64, b: u64) -> Result<u64, ProgramError> {
6 a.checked_add(b).ok_or(ProgramError::ArithmeticOverflow)
7}
8
9/// Checked u64 subtraction: returns `ArithmeticOverflow` on underflow.
10#[inline(always)]
11pub fn checked_sub(a: u64, b: u64) -> Result<u64, ProgramError> {
12 a.checked_sub(b).ok_or(ProgramError::ArithmeticOverflow)
13}
14
15/// Checked u64 multiplication: returns `ArithmeticOverflow` on overflow.
16#[inline(always)]
17pub fn checked_mul(a: u64, b: u64) -> Result<u64, ProgramError> {
18 a.checked_mul(b).ok_or(ProgramError::ArithmeticOverflow)
19}
20
21/// Checked u64 division: returns `ArithmeticOverflow` on divide-by-zero.
22///
23/// Every AMM price calculation involves division. This is the missing
24/// companion to `checked_add`/`checked_sub`/`checked_mul`.
25///
26/// ```rust,ignore
27/// let price = checked_div(reserve_b, reserve_a)?;
28/// ```
29#[inline(always)]
30pub fn checked_div(a: u64, b: u64) -> Result<u64, ProgramError> {
31 a.checked_div(b).ok_or(ProgramError::ArithmeticOverflow)
32}
33
34/// Checked ceiling division: `ceil(a / b)`. Returns `ArithmeticOverflow` on zero.
35///
36/// Rounds up instead of truncating. Use this for fee calculations and
37/// minimum-output computations where truncation would favor the user
38/// at the protocol's expense.
39///
40/// ```rust,ignore
41/// let fee = checked_div_ceil(amount * fee_rate, 10_000)?;
42/// ```
43#[inline(always)]
44pub fn checked_div_ceil(a: u64, b: u64) -> Result<u64, ProgramError> {
45 if b == 0 {
46 return Err(ProgramError::ArithmeticOverflow);
47 }
48 // ceil(a / b) = (a + b - 1) / b, guarding against overflow in a + b - 1
49 Ok(a.checked_add(b - 1)
50 .ok_or(ProgramError::ArithmeticOverflow)?
51 / b)
52}
53
54/// Compute `(a * b) / c` with u128 intermediate to prevent overflow.
55///
56/// **The single most important DeFi math primitive.** Without u128
57/// intermediate, `a * b` overflows for any token amounts > ~4.2B.
58/// Returns floor division.
59///
60/// ```rust,ignore
61/// // Constant-product swap: dy = (y * dx) / (x + dx)
62/// let output = checked_mul_div(reserve_y, input, reserve_x + input)?;
63/// ```
64#[inline(always)]
65pub fn checked_mul_div(a: u64, b: u64, c: u64) -> Result<u64, ProgramError> {
66 if c == 0 {
67 return Err(ProgramError::ArithmeticOverflow);
68 }
69 let result = (a as u128)
70 .checked_mul(b as u128)
71 .ok_or(ProgramError::ArithmeticOverflow)?
72 / (c as u128);
73 to_u64(result)
74}
75
76/// Compute `ceil((a * b) / c)` with u128 intermediate.
77///
78/// Same as `checked_mul_div` but rounds up. Use this for fee calculations
79/// to ensure the protocol never gets rounded down to zero fee.
80///
81/// ```rust,ignore
82/// let fee = checked_mul_div_ceil(amount, fee_bps, 10_000)?;
83/// ```
84#[inline(always)]
85pub fn checked_mul_div_ceil(a: u64, b: u64, c: u64) -> Result<u64, ProgramError> {
86 if c == 0 {
87 return Err(ProgramError::ArithmeticOverflow);
88 }
89 let numerator = (a as u128)
90 .checked_mul(b as u128)
91 .ok_or(ProgramError::ArithmeticOverflow)?;
92 let c128 = c as u128;
93 // ceil(n / d) = (n + d - 1) / d
94 let result = numerator
95 .checked_add(c128 - 1)
96 .ok_or(ProgramError::ArithmeticOverflow)?
97 / c128;
98 to_u64(result)
99}
100
101/// Compute basis-point fee: `amount * bps / 10_000` (floor).
102///
103/// Uses u128 intermediate to prevent overflow. Nearly every DeFi program
104/// computes fees in basis points — this one-liner eliminates a whole class
105/// of bugs.
106///
107/// ```rust,ignore
108/// let fee = bps_of(trade_amount, 30)?; // 0.3% fee
109/// ```
110#[inline(always)]
111pub fn bps_of(amount: u64, basis_points: u16) -> Result<u64, ProgramError> {
112 checked_mul_div(amount, basis_points as u64, 10_000)
113}
114
115/// Compute basis-point fee with ceiling: `ceil(amount * bps / 10_000)`.
116///
117/// Fees must never round to zero. Use this to ensure the protocol always
118/// collects at least 1 token unit of fee when a fee is configured.
119///
120/// ```rust,ignore
121/// let fee = bps_of_ceil(trade_amount, 30)?; // 0.3% fee, always >= 1
122/// ```
123#[inline(always)]
124pub fn bps_of_ceil(amount: u64, basis_points: u16) -> Result<u64, ProgramError> {
125 checked_mul_div_ceil(amount, basis_points as u64, 10_000)
126}
127
128/// Checked exponentiation via repeated squaring.
129///
130/// Computes `base^exp` with overflow checking at each step. Useful for
131/// compound interest calculations and exponential decay.
132///
133/// ```rust,ignore
134/// let compound = checked_pow(rate_per_period, num_periods)?;
135/// ```
136#[inline(always)]
137pub fn checked_pow(base: u64, exp: u32) -> Result<u64, ProgramError> {
138 if exp == 0 {
139 return Ok(1);
140 }
141 let mut result: u64 = 1;
142 let mut b = base;
143 let mut e = exp;
144 while e > 0 {
145 if e & 1 == 1 {
146 result = result.checked_mul(b).ok_or(ProgramError::ArithmeticOverflow)?;
147 }
148 e >>= 1;
149 if e > 0 {
150 b = b.checked_mul(b).ok_or(ProgramError::ArithmeticOverflow)?;
151 }
152 }
153 Ok(result)
154}
155
156/// Safe narrowing cast from u128 to u64.
157///
158/// Returns `ArithmeticOverflow` if the value exceeds `u64::MAX`.
159/// Use this after u128 intermediate computations.
160///
161/// ```rust,ignore
162/// let result_u128: u128 = (a as u128) * (b as u128) / (c as u128);
163/// let result = to_u64(result_u128)?;
164/// ```
165#[inline(always)]
166pub fn to_u64(val: u128) -> Result<u64, ProgramError> {
167 if val > u64::MAX as u128 {
168 return Err(ProgramError::ArithmeticOverflow);
169 }
170 Ok(val as u64)
171}