1use crate::decimal::{Decimal, MAX_SCALE};
16use crate::error::ArithmeticError;
17
/// Powers of ten `10^0 ..= 10^28`, indexed by exponent.
///
/// Built by a compile-time loop rather than spelled out literally, so every entry is
/// guaranteed to be exactly ten times the previous one.
pub(crate) const POW10: [i128; 29] = {
    let mut table = [1i128; 29];
    let mut i = 1;
    while i < table.len() {
        table[i] = table[i - 1] * 10;
        i += 1;
    }
    table
};
55
56#[inline]
58pub(crate) fn pow10(exp: u8) -> Result<i128, ArithmeticError> {
59 POW10.get(exp as usize).copied().ok_or(ArithmeticError::ScaleExceeded)
60}
61
62#[inline]
69pub(crate) fn align_scales(
70 a: Decimal,
71 b: Decimal,
72) -> Result<(i128, i128, u8), ArithmeticError> {
73 use core::cmp::Ordering;
74 match a.scale.cmp(&b.scale) {
75 Ordering::Equal => Ok((a.mantissa, b.mantissa, a.scale)),
76 Ordering::Less => {
77 let diff = b.scale - a.scale;
78 let factor = pow10(diff)?;
79 let scaled = a
80 .mantissa
81 .checked_mul(factor)
82 .ok_or(ArithmeticError::Overflow)?;
83 Ok((scaled, b.mantissa, b.scale))
84 }
85 Ordering::Greater => {
86 let diff = a.scale - b.scale;
87 let factor = pow10(diff)?;
88 let scaled = b
89 .mantissa
90 .checked_mul(factor)
91 .ok_or(ArithmeticError::Overflow)?;
92 Ok((a.mantissa, scaled, a.scale))
93 }
94 }
95}
96
/// Sign of a three-operand product/quotient, as computed by `sign3`.
///
/// `Debug`/`PartialEq`/`Eq` are derived so values can be printed and compared in
/// assertions and diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Sign {
    /// The result is strictly positive.
    Positive,
    /// The result is strictly negative.
    Negative,
    /// The result is zero.
    Zero,
}
105
106#[inline]
108pub(crate) fn sign3(a: i128, b: i128, c: i128) -> Sign {
109 if a == 0 || b == 0 {
110 return Sign::Zero;
111 }
112 let neg_a = a < 0;
113 let neg_b = b < 0;
114 let neg_c = c < 0;
115 let negative = (neg_a ^ neg_b) ^ neg_c;
116 if negative { Sign::Negative } else { Sign::Positive }
117}
118
/// An unsigned 256-bit integer stored as two 128-bit limbs (`hi * 2^128 + lo`).
///
/// Only what the decimal code needs is provided: a widening 128x128-bit multiply and a
/// division by a 128-bit divisor whose quotient must fit in 128 bits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct U256 {
    /// Low 128 bits.
    pub lo: u128,
    /// High 128 bits.
    pub hi: u128,
}

impl U256 {
    /// The value zero.
    pub const ZERO: Self = Self { lo: 0, hi: 0 };

    /// Returns the full 256-bit product `a * b`. This can never overflow.
    pub fn mul(a: u128, b: u128) -> Self {
        // Split both operands into 64-bit halves so each partial product fits in a u128.
        const MASK64: u128 = u64::MAX as u128;
        let a_lo = a & MASK64;
        let a_hi = a >> 64;
        let b_lo = b & MASK64;
        let b_hi = b >> 64;

        let ll = a_lo * b_lo;
        let lh = a_lo * b_hi;
        let hl = a_hi * b_lo;
        let hh = a_hi * b_hi;

        // The two cross terms sit at bit offset 64. Their sum can carry out of 128 bits,
        // and that carry is worth 2^192, i.e. 2^64 in the high limb.
        let (mid, mid_carry) = lh.overflowing_add(hl);
        let (lo, lo_carry) = ll.overflowing_add(mid << 64);
        let hi = hh
            .wrapping_add(mid >> 64)
            .wrapping_add(if mid_carry { 1u128 << 64 } else { 0 })
            .wrapping_add(lo_carry as u128);

        U256 { lo, hi }
    }

    /// Divides `self` by `d`, returning `(quotient, remainder)`.
    ///
    /// Returns `None` when `d == 0`, or when the quotient does not fit in a `u128`.
    /// The latter happens exactly when `self.hi >= d`.
    pub fn checked_div(self, d: u128) -> Option<(u128, u128)> {
        if d == 0 {
            return None;
        }
        if self.hi == 0 {
            return Some((self.lo / d, self.lo % d));
        }
        if self.hi >= d {
            return None;
        }

        // From here on `self.hi < d`. The quotient therefore fits in 128 bits, and `self.hi`
        // is already the partial remainder left after consuming the high limb.
        if d <= u64::MAX as u128 {
            // Base-2^64 long division over the two 64-bit digits of `lo`.
            // Since `hi < d < 2^64`, every intermediate `rem * 2^64 + digit` fits in a u128,
            // and each digit quotient is below 2^64.
            const MASK: u128 = u64::MAX as u128;
            let upper = (self.hi << 64) | (self.lo >> 64);
            let q_upper = upper / d;
            let r_upper = upper % d;
            let lower = (r_upper << 64) | (self.lo & MASK);
            let q_lower = lower / d;
            let r_lower = lower % d;
            return Some(((q_upper << 64) | q_lower, r_lower));
        }

        // Binary restoring long division over the 128 bits of `lo`, seeded with `hi`.
        let mut quotient: u128 = 0;
        let mut rem = self.hi;
        for i in (0..128u32).rev() {
            // `rem < d < 2^128`, so `2 * rem + bit` needs up to 129 bits. Remember the bit
            // shifted out of the top.
            let shifted_out = (rem >> 127) == 1;
            rem = (rem << 1) | ((self.lo >> i) & 1);
            if shifted_out || rem >= d {
                // If a bit was shifted out, the true value 2^128 + rem is >= d and the true
                // difference is < d. Wrapping subtraction therefore yields it exactly.
                rem = rem.wrapping_sub(d);
                quotient |= 1u128 << i;
            }
        }

        Some((quotient, rem))
    }
}
269
270impl Decimal {
273 pub fn checked_add(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
275 let (a, b, scale) = align_scales(self, rhs)?;
276 let mantissa = a.checked_add(b).ok_or(ArithmeticError::Overflow)?;
277 Decimal::new(mantissa, scale)
278 }
279
280 pub fn checked_sub(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
282 let (a, b, scale) = align_scales(self, rhs)?;
283 let mantissa = a.checked_sub(b).ok_or(ArithmeticError::Overflow)?;
284 Decimal::new(mantissa, scale)
285 }
286
287 pub fn checked_mul(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
292 let new_scale = self
293 .scale
294 .checked_add(rhs.scale)
295 .filter(|&s| s <= MAX_SCALE)
296 .ok_or(ArithmeticError::ScaleExceeded)?;
297 let mantissa = self
298 .mantissa
299 .checked_mul(rhs.mantissa)
300 .ok_or(ArithmeticError::Overflow)?;
301 Decimal::new(mantissa, new_scale)
302 }
303
304 pub fn checked_div(self, rhs: Decimal) -> Result<Decimal, ArithmeticError> {
309 if rhs.mantissa == 0 {
310 return Err(ArithmeticError::DivisionByZero);
311 }
312 let extra = MAX_SCALE.saturating_sub(self.scale);
313 let factor = pow10(extra)?;
314 let scaled_num = self
315 .mantissa
316 .checked_mul(factor)
317 .ok_or(ArithmeticError::Overflow)?;
318 let mantissa = scaled_num
319 .checked_div(rhs.mantissa)
320 .ok_or(ArithmeticError::Overflow)?;
321 let raw_scale = (self.scale as i32) + (extra as i32) - (rhs.scale as i32);
322 if raw_scale < 0 {
323 return Err(ArithmeticError::Underflow);
324 }
325 Decimal::new(mantissa, (raw_scale as u8).min(MAX_SCALE))
326 }
327
328 pub fn checked_neg(self) -> Result<Decimal, ArithmeticError> {
333 let mantissa = self.mantissa.checked_neg().ok_or(ArithmeticError::Overflow)?;
334 Decimal::new(mantissa, self.scale)
335 }
336
337 pub fn checked_abs(self) -> Result<Decimal, ArithmeticError> {
341 if self.mantissa >= 0 {
342 return Ok(self);
343 }
344 self.checked_neg()
345 }
346
347 pub fn checked_mul_div(
352 self,
353 numerator: Decimal,
354 denominator: Decimal,
355 ) -> Result<Decimal, ArithmeticError> {
356 if denominator.mantissa == 0 {
357 return Err(ArithmeticError::DivisionByZero);
358 }
359
360 let sign = sign3(self.mantissa, numerator.mantissa, denominator.mantissa);
361
362 let a = self.mantissa.unsigned_abs();
363 let b = numerator.mantissa.unsigned_abs();
364 let c = denominator.mantissa.unsigned_abs();
365
366 let product = U256::mul(a, b);
367 let (quotient_u128, _rem) =
368 product.checked_div(c).ok_or(ArithmeticError::Overflow)?;
369
370 let mantissa_abs =
371 i128::try_from(quotient_u128).map_err(|_| ArithmeticError::Overflow)?;
372
373 let signed_mantissa = match sign {
374 Sign::Zero => 0i128,
375 Sign::Positive => mantissa_abs,
376 Sign::Negative => {
377 mantissa_abs.checked_neg().ok_or(ArithmeticError::Overflow)?
378 }
379 };
380
381 let num_scale = self.scale as i32 + numerator.scale as i32;
382 let den_scale = denominator.scale as i32;
383 let result_scale = num_scale - den_scale;
384 if result_scale < 0 || result_scale > MAX_SCALE as i32 {
385 return Err(ArithmeticError::ScaleExceeded);
386 }
387
388 Decimal::new(signed_mantissa, result_scale as u8)
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pow10_table_spot_checks() {
        assert_eq!(pow10(0).unwrap(), 1);
        assert_eq!(pow10(6).unwrap(), 1_000_000);
        assert_eq!(pow10(18).unwrap(), 1_000_000_000_000_000_000);
        assert_eq!(
            pow10(28).unwrap(),
            10_000_000_000_000_000_000_000_000_000i128
        );
        assert!(pow10(29).is_err());
    }

    #[test]
    fn pow10_table_matches_powers_of_ten() {
        for exp in 0..=28u8 {
            assert_eq!(pow10(exp).unwrap(), 10i128.pow(u32::from(exp)));
        }
        assert!(pow10(u8::MAX).is_err());
    }

    #[test]
    fn u256_mul_small() {
        assert_eq!(U256::mul(3, 7), U256 { lo: 21, hi: 0 });
    }

    #[test]
    fn u256_mul_max_times_max() {
        let r = U256::mul(u128::MAX, u128::MAX);
        assert_eq!(r.lo, 1);
        assert_eq!(r.hi, u128::MAX - 1);
    }

    #[test]
    fn u256_mul_carries_into_high_limb() {
        assert_eq!(U256::mul(1u128 << 64, 1u128 << 64), U256 { lo: 0, hi: 1 });
    }

    #[test]
    fn u256_div_basic() {
        assert_eq!(U256 { lo: 21, hi: 0 }.checked_div(7), Some((3, 0)));
    }

    #[test]
    fn u256_div_by_zero() {
        assert_eq!(U256::ZERO.checked_div(0), None);
    }

    #[test]
    fn u256_div_overflow_check() {
        assert_eq!(U256 { lo: 0, hi: 100 }.checked_div(50), None);
    }

    #[test]
    fn u256_div_quotient_exactly_u128_max() {
        // d * 2^128 - 1 == d * u128::MAX + (d - 1). The divisors cover both the
        // 64-bit fast path and the bitwise slow path.
        for &d in &[7u128, u64::MAX as u128, u64::MAX as u128 + 1, 1u128 << 100, u128::MAX] {
            let n = U256 { lo: u128::MAX, hi: d - 1 };
            assert_eq!(n.checked_div(d), Some((u128::MAX, d - 1)));
        }
    }

    #[test]
    fn u256_div_round_trips_mul_with_remainder() {
        let a = 0x0123_4567_89ab_cdef_0123_4567_89ab_cdef_u128;
        for &d in &[3u128, 1_000_000_007, (1u128 << 90) + 12_345, u128::MAX - 2] {
            let r = d - 1;
            let p = U256::mul(a, d);
            let (lo, carry) = p.lo.overflowing_add(r);
            let n = U256 { lo, hi: p.hi + u128::from(carry) };
            assert_eq!(n.checked_div(d), Some((a, r)));
        }
    }

    #[test]
    fn sign3_follows_sign_rules() {
        assert!(matches!(sign3(0, -5, 3), Sign::Zero));
        assert!(matches!(sign3(5, 0, -3), Sign::Zero));
        assert!(matches!(sign3(2, 3, 4), Sign::Positive));
        assert!(matches!(sign3(-2, 3, 4), Sign::Negative));
        assert!(matches!(sign3(-2, -3, 4), Sign::Positive));
        assert!(matches!(sign3(-2, -3, -4), Sign::Negative));
    }
}
436}