1use crate::decimal::{Decimal, MAX_SCALE};
16use crate::error::ArithmeticError;
17
/// Powers of ten `10^0 ..= 10^28`, indexed by exponent.
///
/// Generated at compile time by repeated multiplication, so no entry can be
/// mistyped. `10^28` is far below `i128::MAX` (about 1.7 × 10^38), so the
/// multiplications cannot overflow.
pub(crate) const POW10: [i128; 29] = {
    let mut table = [1i128; 29];
    let mut exp = 1;
    while exp < table.len() {
        table[exp] = table[exp - 1] * 10;
        exp += 1;
    }
    table
};
55
56#[inline]
58pub(crate) fn pow10(exp: u8) -> Result<i128, ArithmeticError> {
59 POW10
60 .get(exp as usize)
61 .copied()
62 .ok_or(ArithmeticError::ScaleExceeded)
63}
64
65#[inline]
72pub(crate) fn align_scales(a: Decimal, b: Decimal) -> Result<(i128, i128, u8), ArithmeticError> {
73 use core::cmp::Ordering;
74 match a.scale.cmp(&b.scale) {
75 Ordering::Equal => Ok((a.mantissa, b.mantissa, a.scale)),
76 Ordering::Less => {
77 let diff = b.scale - a.scale;
78 let factor = pow10(diff)?;
79 let scaled = a
80 .mantissa
81 .checked_mul(factor)
82 .ok_or(ArithmeticError::Overflow)?;
83 Ok((scaled, b.mantissa, b.scale))
84 }
85 Ordering::Greater => {
86 let diff = a.scale - b.scale;
87 let factor = pow10(diff)?;
88 let scaled = b
89 .mantissa
90 .checked_mul(factor)
91 .ok_or(ArithmeticError::Overflow)?;
92 Ok((a.mantissa, scaled, a.scale))
93 }
94 }
95}
96
/// Sign of a value of the form `a * b / c`, as classified by [`sign3`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Sign {
    /// Strictly greater than zero.
    Positive,
    /// Strictly less than zero.
    Negative,
    /// Exactly zero, because one of the factors `a` or `b` is zero.
    Zero,
}

/// Classifies the sign of `a * b / c` without performing the (possibly
/// overflowing) arithmetic.
///
/// `c` is only inspected for its sign bit, so a zero `c` is classified like a
/// positive one. `checked_mul_div` rejects a zero divisor before calling this.
#[inline]
pub(crate) fn sign3(a: i128, b: i128, c: i128) -> Sign {
    if a == 0 || b == 0 {
        return Sign::Zero;
    }
    // The result is negative exactly when an odd number of the inputs are negative.
    let negative = (a < 0) ^ (b < 0) ^ (c < 0);
    if negative {
        Sign::Negative
    } else {
        Sign::Positive
    }
}
122
/// Unsigned 256-bit integer used as a widening intermediate for `a * b / c`.
///
/// The represented value is `hi * 2^128 + lo`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct U256 {
    /// Low 128 bits of the value.
    pub lo: u128,
    /// High 128 bits of the value.
    pub hi: u128,
}

impl U256 {
    /// The value zero (test helper).
    #[cfg(test)]
    pub const ZERO: Self = Self { lo: 0, hi: 0 };

    /// Full-width product `a * b`; the result always fits in 256 bits.
    pub fn mul(a: u128, b: u128) -> Self {
        // Schoolbook multiplication on 64-bit limbs: every partial product of
        // two values below 2^64 fits in a u128.
        const MASK64: u128 = u64::MAX as u128;
        let (a_lo, a_hi) = (a & MASK64, a >> 64);
        let (b_lo, b_hi) = (b & MASK64, b >> 64);

        let ll = a_lo * b_lo;
        let lh = a_lo * b_hi;
        let hl = a_hi * b_lo;
        let hh = a_hi * b_hi;

        // The cross terms sit at bit 64. If their sum carries out of 128 bits,
        // the carry is worth 2^192, i.e. bit 64 of `hi`.
        let (mid, mid_carry) = lh.overflowing_add(hl);
        let (lo, lo_carry) = ll.overflowing_add(mid << 64);
        let hi = hh
            .wrapping_add(mid >> 64)
            .wrapping_add(if mid_carry { 1u128 << 64 } else { 0 })
            .wrapping_add(u128::from(lo_carry));

        U256 { lo, hi }
    }

    /// Divides `self` by `d`, returning `(quotient, remainder)`.
    ///
    /// Returns `None` when `d == 0`, or when the quotient does not fit in a
    /// `u128`. The latter is exactly the case `self.hi >= d`.
    pub fn checked_div(self, d: u128) -> Option<(u128, u128)> {
        if d == 0 {
            return None;
        }
        if self.hi == 0 {
            // Plain 128-bit division.
            return Some((self.lo / d, self.lo % d));
        }
        if self.hi >= d {
            // Quotient would be >= 2^128.
            return None;
        }

        // From here on `self.hi < d`. So `self.hi` is already the remainder after
        // dividing the high half, and only the 128 bits of `lo` remain.
        if d <= u128::from(u64::MAX) {
            // Two 64-bit long-division steps. Because `hi < d < 2^64`, each partial
            // dividend fits in a u128 and each partial quotient is below 2^64.
            let n1 = (self.hi << 64) | (self.lo >> 64);
            let (q1, r1) = (n1 / d, n1 % d);
            let n2 = (r1 << 64) | (self.lo & u128::from(u64::MAX));
            let (q2, r2) = (n2 / d, n2 % d);
            return Some(((q1 << 64) | q2, r2));
        }

        // Restoring binary long division over the low 128 bits, seeded with the
        // high half as the running remainder (invariant: `rem < d`).
        let mut rem = self.hi;
        let mut quot: u128 = 0;
        for i in (0..128u32).rev() {
            // `2 * rem + bit` can need 129 bits; remember the bit shifted out.
            let carry = rem >> 127 == 1;
            let shifted = (rem << 1) | ((self.lo >> i) & 1);
            if carry || shifted >= d {
                // With `carry`, the true value is `2^128 + shifted`, which is >= d and
                // < 2d. Wrapping subtraction therefore yields the exact remainder (< d).
                rem = shifted.wrapping_sub(d);
                quot |= 1u128 << i;
            } else {
                rem = shifted;
            }
        }

        Some((quot, rem))
    }
}
274
275impl Decimal {
278 pub fn checked_add(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
280 let (a, b, scale) = align_scales(self, rhs)?;
281 let mantissa = a.checked_add(b).ok_or(ArithmeticError::Overflow)?;
282 Decimal::new(mantissa, scale)
283 }
284
285 pub fn checked_sub(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
287 let (a, b, scale) = align_scales(self, rhs)?;
288 let mantissa = a.checked_sub(b).ok_or(ArithmeticError::Overflow)?;
289 Decimal::new(mantissa, scale)
290 }
291
292 pub fn checked_mul(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
297 let new_scale = self
298 .scale
299 .checked_add(rhs.scale)
300 .filter(|&s| s <= MAX_SCALE)
301 .ok_or(ArithmeticError::ScaleExceeded)?;
302 let mantissa = self
303 .mantissa
304 .checked_mul(rhs.mantissa)
305 .ok_or(ArithmeticError::Overflow)?;
306 Decimal::new(mantissa, new_scale)
307 }
308
309 pub fn checked_div(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
314 if rhs.mantissa == 0 {
315 return Err(ArithmeticError::DivisionByZero);
316 }
317 let extra = MAX_SCALE.saturating_sub(self.scale);
318 let factor = pow10(extra)?;
319 let scaled_num = self
320 .mantissa
321 .checked_mul(factor)
322 .ok_or(ArithmeticError::Overflow)?;
323 let mantissa = scaled_num
324 .checked_div(rhs.mantissa)
325 .ok_or(ArithmeticError::Overflow)?;
326 let raw_scale = (self.scale as i32) + (extra as i32) - (rhs.scale as i32);
327 if raw_scale < 0 {
328 return Err(ArithmeticError::Underflow);
329 }
330 Decimal::new(mantissa, (raw_scale as u8).min(MAX_SCALE))
331 }
332
333 pub fn checked_neg(self) -> Result<Decimal, ArithmeticError> {
338 let mantissa = self
339 .mantissa
340 .checked_neg()
341 .ok_or(ArithmeticError::Overflow)?;
342 Decimal::new(mantissa, self.scale)
343 }
344
345 pub fn checked_abs(self) -> Result<Decimal, ArithmeticError> {
349 if self.mantissa >= 0 {
350 return Ok(self);
351 }
352 self.checked_neg()
353 }
354
355 pub fn checked_mul_div(
360 self,
361 numerator: Decimal,
362 denominator: Decimal,
363 ) -> Result<Decimal, ArithmeticError> {
364 if denominator.mantissa == 0 {
365 return Err(ArithmeticError::DivisionByZero);
366 }
367
368 let sign = sign3(self.mantissa, numerator.mantissa, denominator.mantissa);
369
370 let a = self.mantissa.unsigned_abs();
371 let b = numerator.mantissa.unsigned_abs();
372 let c = denominator.mantissa.unsigned_abs();
373
374 let product = U256::mul(a, b);
375 let (quotient_u128, _rem) = product.checked_div(c).ok_or(ArithmeticError::Overflow)?;
376
377 let mantissa_abs = i128::try_from(quotient_u128).map_err(|_| ArithmeticError::Overflow)?;
378
379 let signed_mantissa = match sign {
380 Sign::Zero => 0i128,
381 Sign::Positive => mantissa_abs,
382 Sign::Negative => mantissa_abs
383 .checked_neg()
384 .ok_or(ArithmeticError::Overflow)?,
385 };
386
387 let num_scale = self.scale as i32 + numerator.scale as i32;
388 let den_scale = denominator.scale as i32;
389 let result_scale = num_scale - den_scale;
390 if result_scale < 0 || result_scale > MAX_SCALE as i32 {
391 return Err(ArithmeticError::ScaleExceeded);
392 }
393
394 Decimal::new(signed_mantissa, result_scale as u8)
395 }
396}
397
#[cfg(test)]
mod tests {
    use super::*;

    /// Rebuilds `q * d + r` as a `U256` so a division result can be checked exactly.
    fn recombine(q: u128, d: u128, r: u128) -> U256 {
        let product = U256::mul(q, d);
        let (lo, carry) = product.lo.overflowing_add(r);
        U256 {
            lo,
            hi: product.hi + u128::from(carry),
        }
    }

    #[test]
    fn pow10_table_spot_checks() {
        assert_eq!(pow10(0).unwrap(), 1);
        assert_eq!(pow10(6).unwrap(), 1_000_000);
        assert_eq!(pow10(18).unwrap(), 1_000_000_000_000_000_000);
        assert_eq!(
            pow10(28).unwrap(),
            10_000_000_000_000_000_000_000_000_000i128
        );
        assert!(pow10(29).is_err());
    }

    #[test]
    fn pow10_table_is_consecutive_powers_of_ten() {
        for exp in 1..=28u8 {
            assert_eq!(pow10(exp).unwrap(), pow10(exp - 1).unwrap() * 10);
        }
    }

    #[test]
    fn u256_mul_small() {
        assert_eq!(U256::mul(3, 7), U256 { lo: 21, hi: 0 });
    }

    #[test]
    fn u256_mul_max_times_max() {
        let r = U256::mul(u128::MAX, u128::MAX);
        assert_eq!(r.lo, 1);
        assert_eq!(r.hi, u128::MAX - 1);
    }

    #[test]
    fn u256_div_basic() {
        assert_eq!(U256 { lo: 21, hi: 0 }.checked_div(7), Some((3, 0)));
    }

    #[test]
    fn u256_div_by_zero() {
        assert_eq!(U256::ZERO.checked_div(0), None);
    }

    #[test]
    fn u256_div_overflow_check() {
        assert_eq!(U256 { lo: 0, hi: 100 }.checked_div(50), None);
    }

    #[test]
    fn u256_div_small_divisor_with_remainder() {
        // 2^128 + 7 = 3 * (u128::MAX / 3 + 2) + 2
        assert_eq!(
            U256 { lo: 7, hi: 1 }.checked_div(3),
            Some((u128::MAX / 3 + 2, 2))
        );
    }

    #[test]
    fn u256_div_wide_divisor() {
        // Divisors above u64::MAX take the bitwise long-division path.
        let x = U256 {
            lo: (7u128 << 100) | 123,
            hi: 5,
        };
        assert_eq!(x.checked_div(1u128 << 100), Some(((5u128 << 28) | 7, 123)));
        let max_sq = U256::mul(u128::MAX, u128::MAX);
        assert_eq!(max_sq.checked_div(u128::MAX), Some((u128::MAX, 0)));
    }

    #[test]
    fn u256_div_round_trips() {
        let cases = [
            (
                U256 {
                    lo: 0x0123_4567_89ab_cdef_fedc_ba98_7654_3210,
                    hi: 0x1111,
                },
                (1u128 << 127) + 987_654_321,
            ),
            (
                U256 {
                    lo: u128::MAX,
                    hi: 12_345,
                },
                98_765,
            ),
            (U256::mul(u128::MAX - 5, 1u128 << 90), (1u128 << 90) + 1),
        ];
        for &(x, d) in cases.iter() {
            let (q, r) = x.checked_div(d).expect("hi < d, so the quotient fits");
            assert!(r < d);
            assert_eq!(recombine(q, d, r), x);
        }
    }
}