1use crate::{conversions::to_u32, errors::ParclMathErrorCode, uint::U256};
2use anchor_lang::prelude::*;
3use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
4
5type InnerUint = U256;
7
8pub const ONE: u128 = 1_000_000_000_000;
10
11pub const BPS_EXPO: i32 = -4;
13
14#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
16pub struct PreciseNumber {
17 pub val: InnerUint,
19}
20
21fn one() -> InnerUint {
23 InnerUint::from(ONE)
24}
25
26fn zero() -> InnerUint {
28 InnerUint::from(0)
29}
30
31impl PreciseNumber {
32 fn rounding_correction() -> InnerUint {
36 InnerUint::from(ONE / 2)
37 }
38
39 fn precision() -> InnerUint {
44 InnerUint::from(100)
45 }
46
47 pub fn zero() -> Self {
48 Self { val: zero() }
49 }
50
51 pub fn one() -> Self {
52 Self { val: one() }
53 }
54
55 const MAX_APPROXIMATION_ITERATIONS: u128 = 100;
57
58 fn min_pow_base() -> InnerUint {
61 InnerUint::from(1)
62 }
63
64 fn max_pow_base() -> InnerUint {
70 InnerUint::from(2 * ONE)
71 }
72
73 pub fn new(val: u128) -> Result<Self> {
75 let val = InnerUint::from(val)
76 .checked_mul(one())
77 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
78 Ok(Self { val })
79 }
80
81 pub fn from(val: u128) -> Self {
83 let val = InnerUint::from(val);
84 Self { val }
85 }
86
87 pub fn from_bps(bps: u16) -> Result<Self> {
89 Self::from_decimal(bps.into(), BPS_EXPO)
90 }
91
92 pub fn from_decimal(decimal: u128, exponent: i32) -> Result<Self> {
94 let precision_expo = 12 + exponent;
95 let mut precision = 10u128
96 .checked_pow(to_u32(precision_expo.abs())?)
97 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
98 if precision_expo < 0 {
99 precision = ONE
100 .checked_div(precision)
101 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
102 }
103 let val = InnerUint::from(
104 decimal
105 .checked_mul(precision)
106 .ok_or(ParclMathErrorCode::IntegerOverflow)?,
107 );
108 Ok(Self { val })
109 }
110
111 pub fn to_imprecise(&self) -> Result<u128> {
113 let val = self
114 .val
115 .checked_add(Self::rounding_correction())
116 .ok_or(ParclMathErrorCode::IntegerOverflow)?
117 .checked_div(one())
118 .ok_or(ParclMathErrorCode::IntegerOverflow)?
119 .as_u128();
120 Ok(val)
121 }
122
123 pub fn to_imprecise_u64(&self) -> Result<u64> {
125 let val = self
126 .val
127 .checked_add(Self::rounding_correction())
128 .ok_or(ParclMathErrorCode::IntegerOverflow)?
129 .checked_div(one())
130 .ok_or(ParclMathErrorCode::IntegerOverflow)?
131 .as_u64();
132 Ok(val)
133 }
134
135 pub fn mul_up(self, rhs: Self) -> Result<Self> {
136 Ok(Self::from(
137 self.val
138 .as_u128()
139 .checked_mul(rhs.val.as_u128())
140 .ok_or(ParclMathErrorCode::IntegerOverflow)?
141 .checked_add(
142 ONE.checked_sub(1)
143 .ok_or(ParclMathErrorCode::IntegerOverflow)?,
144 )
145 .ok_or(ParclMathErrorCode::IntegerOverflow)?
146 .checked_div(ONE)
147 .ok_or(ParclMathErrorCode::IntegerOverflow)?,
148 ))
149 }
150
151 pub fn div_up(self, rhs: Self) -> Result<Self> {
152 Ok(Self::from(
153 self.val
154 .as_u128()
155 .checked_mul(ONE)
156 .unwrap()
157 .checked_add(rhs.val.as_u128().checked_sub(1).unwrap())
158 .unwrap()
159 .checked_div(rhs.val.as_u128())
160 .unwrap(),
161 ))
162 }
163
164 pub fn almost_eq(&self, rhs: &Self, precision: InnerUint) -> bool {
166 let (difference, _) = self.unsigned_sub(rhs);
167 difference.val < precision
168 }
169
170 pub fn less_than(&self, rhs: &Self) -> bool {
172 self.val < rhs.val
173 }
174
175 pub fn greater_than(&self, rhs: &Self) -> bool {
177 self.val > rhs.val
178 }
179
180 pub fn less_than_or_equal(&self, rhs: &Self) -> bool {
182 self.val <= rhs.val
183 }
184
185 pub fn greater_than_or_equal(&self, rhs: &Self) -> bool {
187 self.val >= rhs.val
188 }
189
190 pub fn floor(&self) -> Result<Self> {
192 let one = one();
193 let val = self
194 .val
195 .checked_div(one)
196 .ok_or(ParclMathErrorCode::IntegerOverflow)?
197 .checked_mul(one)
198 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
199 Ok(Self { val })
200 }
201
202 pub fn ceil(&self) -> Result<Self> {
204 let one = one();
205 let val = self
206 .val
207 .checked_add(
208 one.checked_sub(InnerUint::from(1))
209 .ok_or(ParclMathErrorCode::IntegerOverflow)?,
210 )
211 .ok_or(ParclMathErrorCode::IntegerOverflow)?
212 .checked_div(one)
213 .ok_or(ParclMathErrorCode::IntegerOverflow)?
214 .checked_mul(one)
215 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
216 Ok(Self { val })
217 }
218
219 pub fn checked_div(&self, rhs: &Self) -> Result<Self> {
221 if *rhs == Self::zero() {
222 return Err(error!(ParclMathErrorCode::IntegerOverflow));
223 }
224 match self.val.checked_mul(one()) {
225 Some(v) => {
226 let val = v
227 .checked_add(Self::rounding_correction())
228 .ok_or(ParclMathErrorCode::IntegerOverflow)?
229 .checked_div(rhs.val)
230 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
231 Ok(Self { val })
232 }
233 None => {
234 let val = self
235 .val
236 .checked_add(Self::rounding_correction())
237 .ok_or(ParclMathErrorCode::IntegerOverflow)?
238 .checked_div(rhs.val)
239 .ok_or(ParclMathErrorCode::IntegerOverflow)?
240 .checked_mul(one())
241 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
242 Ok(Self { val })
243 }
244 }
245 }
246
247 pub fn checked_mul(&self, rhs: &Self) -> Result<Self> {
249 let one = one();
250 match self.val.checked_mul(rhs.val) {
251 Some(v) => {
252 let val = v
253 .checked_add(Self::rounding_correction())
254 .ok_or(ParclMathErrorCode::IntegerOverflow)?
255 .checked_div(one)
256 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
257 Ok(Self { val })
258 }
259 None => {
260 let val = if self.val >= rhs.val {
261 self.val
262 .checked_div(one)
263 .ok_or(ParclMathErrorCode::IntegerOverflow)?
264 .checked_mul(rhs.val)
265 .ok_or(ParclMathErrorCode::IntegerOverflow)?
266 } else {
267 rhs.val
268 .checked_div(one)
269 .ok_or(ParclMathErrorCode::IntegerOverflow)?
270 .checked_mul(self.val)
271 .ok_or(ParclMathErrorCode::IntegerOverflow)?
272 };
273 Ok(Self { val })
274 }
275 }
276 }
277
278 pub fn checked_add(&self, rhs: &Self) -> Result<Self> {
280 let val = self
281 .val
282 .checked_add(rhs.val)
283 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
284 Ok(Self { val })
285 }
286
287 pub fn checked_sub(&self, rhs: &Self) -> Result<Self> {
289 let val = self
290 .val
291 .checked_sub(rhs.val)
292 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
293 Ok(Self { val })
294 }
295
296 pub fn unsigned_sub(&self, rhs: &Self) -> (Self, bool) {
298 match self.val.checked_sub(rhs.val) {
299 None => {
300 let val = rhs.val.checked_sub(self.val).unwrap();
301 (Self { val }, true)
302 }
303 Some(val) => (Self { val }, false),
304 }
305 }
306
307 pub fn checked_pow(&self, exponent: u128) -> Result<Self> {
309 let val = if exponent
312 .checked_rem(2)
313 .ok_or(ParclMathErrorCode::IntegerOverflow)?
314 == 0
315 {
316 one()
317 } else {
318 self.val
319 };
320 let mut result = Self { val };
321
322 let mut squared_base = *self;
326 let mut current_exponent = exponent
327 .checked_div(2)
328 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
329 while current_exponent != 0 {
330 squared_base = squared_base.checked_mul(&squared_base)?;
331
332 if current_exponent
334 .checked_rem(2)
335 .ok_or(ParclMathErrorCode::IntegerOverflow)?
336 != 0
337 {
338 result = result.checked_mul(&squared_base)?;
339 }
340
341 current_exponent = current_exponent
342 .checked_div(2)
343 .ok_or(ParclMathErrorCode::IntegerOverflow)?;
344 }
345 Ok(result)
346 }
347
348 fn checked_pow_approximation(&self, exponent: &Self, max_iterations: u128) -> Result<Self> {
367 assert!(self.val >= Self::min_pow_base());
368 assert!(self.val <= Self::max_pow_base());
369 let one = Self::one();
370 if *exponent == Self::zero() {
371 return Ok(one);
372 }
373 let mut precise_guess = one;
374 let mut term = precise_guess;
375 let (x_minus_a, x_minus_a_negative) = self.unsigned_sub(&precise_guess);
376 let exponent_plus_one = exponent.checked_add(&one)?;
377 let mut negative = false;
378 for k in 1..max_iterations {
379 let k = Self::new(k)?;
380 let (current_exponent, current_exponent_negative) = exponent_plus_one.unsigned_sub(&k);
381 term = term.checked_mul(¤t_exponent)?;
382 term = term.checked_mul(&x_minus_a)?;
383 term = term.checked_div(&k)?;
384 if term.val < Self::precision() {
385 break;
386 }
387 if x_minus_a_negative {
388 negative = !negative;
389 }
390 if current_exponent_negative {
391 negative = !negative;
392 }
393 if negative {
394 precise_guess = precise_guess.checked_sub(&term)?;
395 } else {
396 precise_guess = precise_guess.checked_add(&term)?;
397 }
398 }
399 Ok(precise_guess)
400 }
401
402 #[allow(dead_code)]
407 fn checked_pow_fraction(&self, exponent: &Self) -> Result<Self> {
408 assert!(self.val >= Self::min_pow_base());
409 assert!(self.val <= Self::max_pow_base());
410 let whole_exponent = exponent.floor()?;
411 let precise_whole = self.checked_pow(whole_exponent.to_imprecise()?)?;
412 let (remainder_exponent, negative) = exponent.unsigned_sub(&whole_exponent);
413 assert!(!negative);
414 if remainder_exponent.val == InnerUint::from(0) {
415 return Ok(precise_whole);
416 }
417 let precise_remainder = self
418 .checked_pow_approximation(&remainder_exponent, Self::MAX_APPROXIMATION_ITERATIONS)?;
419 precise_whole.checked_mul(&precise_remainder)
420 }
421
422 fn newtonian_root_approximation(
427 &self,
428 root: &Self,
429 mut guess: Self,
430 iterations: u128,
431 ) -> Result<Self> {
432 let zero = Self::zero();
433 if *self == zero {
434 return Ok(zero);
435 }
436 if *root == zero {
437 return Err(error!(ParclMathErrorCode::IntegerOverflow));
438 }
439 let one = Self::new(1)?;
440 let root_minus_one = root.checked_sub(&one)?;
441 let root_minus_one_whole = root_minus_one.to_imprecise()?;
442 let mut last_guess = guess;
443 let precision = Self::precision();
444 for _ in 0..iterations {
445 let first_term = root_minus_one.checked_mul(&guess)?;
447 let power = guess.checked_pow(root_minus_one_whole);
448 let second_term = match power {
449 Ok(num) => self.checked_div(&num)?,
450 Err(_) => Self::new(0)?,
451 };
452 guess = first_term.checked_add(&second_term)?.checked_div(root)?;
453 if last_guess.almost_eq(&guess, precision) {
454 break;
455 } else {
456 last_guess = guess;
457 }
458 }
459 Ok(guess)
460 }
461
462 fn minimum_sqrt_base() -> Self {
465 Self {
466 val: InnerUint::from(0),
467 }
468 }
469
470 fn maximum_sqrt_base() -> Self {
473 Self::new(std::u128::MAX).unwrap()
474 }
475
476 pub fn sqrt(&self) -> Result<Self> {
479 if self.less_than(&Self::minimum_sqrt_base())
480 || self.greater_than(&Self::maximum_sqrt_base())
481 {
482 return Err(error!(ParclMathErrorCode::IntegerOverflow));
483 }
484 let two = PreciseNumber::new(2)?;
485 let one = PreciseNumber::new(1)?;
486 let guess = self.checked_add(&one)?.checked_div(&two)?;
489 self.newtonian_root_approximation(&two, guess, Self::MAX_APPROXIMATION_ITERATIONS)
490 }
491}
492
493impl Add<PreciseNumber> for PreciseNumber {
494 type Output = Self;
495 fn add(self, rhs: PreciseNumber) -> Self::Output {
496 self.checked_add(&rhs).unwrap()
497 }
498}
499
500impl Sub<PreciseNumber> for PreciseNumber {
501 type Output = Self;
502 fn sub(self, rhs: PreciseNumber) -> Self::Output {
503 self.checked_sub(&rhs).unwrap()
504 }
505}
506
507impl Mul<PreciseNumber> for PreciseNumber {
508 type Output = Self;
509 fn mul(self, rhs: PreciseNumber) -> Self::Output {
510 self.checked_mul(&rhs).unwrap()
511 }
512}
513
514impl Div<PreciseNumber> for PreciseNumber {
515 type Output = Self;
516 fn div(self, rhs: PreciseNumber) -> Self::Output {
517 self.checked_div(&rhs).unwrap()
518 }
519}
520
521impl AddAssign<PreciseNumber> for PreciseNumber {
522 fn add_assign(&mut self, rhs: PreciseNumber) {
523 self.val.add_assign(rhs.val)
524 }
525}
526
527impl SubAssign<PreciseNumber> for PreciseNumber {
528 fn sub_assign(&mut self, rhs: PreciseNumber) {
529 self.val.sub_assign(rhs.val)
530 }
531}
532
533impl MulAssign<PreciseNumber> for PreciseNumber {
534 fn mul_assign(&mut self, rhs: PreciseNumber) {
535 self.val.mul_assign(rhs.val);
536 self.val.div_assign(one());
537 }
538}
539
540impl DivAssign<PreciseNumber> for PreciseNumber {
541 fn div_assign(&mut self, rhs: PreciseNumber) {
542 self.val.mul_assign(one());
543 self.val.div_assign(rhs.val);
544 }
545}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Asserts that the binomial-series approximation of `base^exponent` lands
    /// within 5_000_000 raw units of `expected`.
    fn check_pow_approximation(base: InnerUint, exponent: InnerUint, expected: InnerUint) {
        let tolerance = InnerUint::from(5_000_000);
        let approximation = PreciseNumber { val: base }
            .checked_pow_approximation(
                &PreciseNumber { val: exponent },
                PreciseNumber::MAX_APPROXIMATION_ITERATIONS,
            )
            .unwrap();
        assert!(approximation.almost_eq(&PreciseNumber { val: expected }, tolerance));
    }

    #[test]
    fn test_root_approximation() {
        let one = one();
        // (base, exponent, expected), all as raw values.
        let cases = [
            (one / 4, one / 2, one / 2),
            (one * 11 / 10, one / 2, InnerUint::from(1_048808848161u128)),
            (one * 4 / 5, one * 2 / 5, InnerUint::from(914610103850u128)),
            (one / 2, one * 4 / 50, InnerUint::from(946057646730u128)),
        ];
        for &(base, exponent, expected) in cases.iter() {
            check_pow_approximation(base, exponent, expected);
        }
    }

    /// Asserts that `base^exponent` computed by `checked_pow_fraction` lands
    /// within `precision` raw units of `expected`.
    fn check_pow_fraction(
        base: InnerUint,
        exponent: InnerUint,
        expected: InnerUint,
        precision: InnerUint,
    ) {
        let power = PreciseNumber { val: base }
            .checked_pow_fraction(&PreciseNumber { val: exponent })
            .unwrap();
        assert!(power.almost_eq(&PreciseNumber { val: expected }, precision));
    }

    #[test]
    fn test_pow_fraction() {
        let one = one();
        let precision = InnerUint::from(50_000_000);
        let less_precision = precision * 1_000;
        // (base, exponent, expected, tolerance), all as raw values.
        let cases = [
            (one, one, one, precision),
            (
                one * 20 / 13,
                one * 50 / 3,
                InnerUint::from(1312_534484739100u128),
                precision,
            ),
            (one * 2 / 7, one * 49 / 4, InnerUint::from(2163), precision),
            (
                one * 5000 / 5100,
                one / 9,
                InnerUint::from(997802126900u128),
                precision,
            ),
            (
                one * 2,
                one * 27 / 5,
                InnerUint::from(42_224253144700u128),
                less_precision,
            ),
            (
                one * 18 / 10,
                one * 11 / 3,
                InnerUint::from(8_629769290500u128),
                less_precision,
            ),
        ];
        for &(base, exponent, expected, tolerance) in cases.iter() {
            check_pow_fraction(base, exponent, expected, tolerance);
        }
    }

    #[test]
    fn test_newtonian_approximation() {
        // Dividing by a zero root must fail, so no initial guess can be formed.
        let zero_value = PreciseNumber::new(0).unwrap();
        let zero_root = PreciseNumber::new(0).unwrap();
        assert!(zero_value.checked_div(&zero_root).is_err());

        // (value, root, expected whole-number result)
        let cases: [(u128, u128, u128); 4] = [
            (9, 2, 3),
            (101, 2, 10),
            (1_000_000_000, 2, 31_623),
            (500, 5, 3),
        ];
        for &(value, nth, expected) in cases.iter() {
            let value = PreciseNumber::new(value).unwrap();
            let nth_root = PreciseNumber::new(nth).unwrap();
            let initial_guess = value.checked_div(&nth_root).unwrap();
            let root = value
                .newtonian_root_approximation(
                    &nth_root,
                    initial_guess,
                    PreciseNumber::MAX_APPROXIMATION_ITERATIONS,
                )
                .unwrap()
                .to_imprecise()
                .unwrap();
            assert_eq!(root, expected);
        }
    }

    #[test]
    fn test_checked_mul() {
        let zero_value = PreciseNumber::new(0).unwrap();
        assert_eq!(
            zero_value.checked_mul(&zero_value).unwrap(),
            PreciseNumber { val: U256::from(0) }
        );

        let two = PreciseNumber::new(2).unwrap();
        assert_eq!(
            two.checked_mul(&two).unwrap(),
            PreciseNumber::new(2 * 2).unwrap()
        );

        // Overflowing raw product: falls back to unscaling the larger operand.
        let max = PreciseNumber { val: U256::MAX };
        let unit = PreciseNumber::new(1).unwrap();
        assert_eq!(
            max.checked_mul(&unit).unwrap().val,
            U256::MAX / one() * one()
        );

        let mut just_over_one = unit;
        just_over_one.val += U256::from(1);
        assert!(max.checked_mul(&just_over_one).is_err());
    }

    /// Asserts that squaring `sqrt(check)` scaled by `1 ± epsilon` brackets
    /// `check`.
    fn check_square_root(check: &PreciseNumber) {
        let epsilon = PreciseNumber {
            val: InnerUint::from(10),
        };
        let unit = PreciseNumber::one();
        let factor_above = unit.checked_add(&epsilon).unwrap();
        let factor_below = unit.checked_sub(&epsilon).unwrap();
        let root = check.sqrt().unwrap();
        let scaled_square =
            |factor: &PreciseNumber| root.checked_mul(factor).unwrap().checked_pow(2).unwrap();
        let lower_bound = scaled_square(&factor_below);
        let upper_bound = scaled_square(&factor_above);
        assert!(check.less_than_or_equal(&upper_bound));
        assert!(check.greater_than_or_equal(&lower_bound));
    }

    #[test]
    fn test_square_root_min_max() {
        let extremes = [
            PreciseNumber::minimum_sqrt_base(),
            PreciseNumber::maximum_sqrt_base(),
        ];
        for base in &extremes {
            check_square_root(base);
        }
    }

    #[test]
    fn test_floor() {
        let expected = PreciseNumber::new(2).unwrap();
        let mut slightly_above = expected;
        slightly_above.val += InnerUint::from(1);
        let floored = slightly_above.floor().unwrap();
        let floored_twice = floored.floor().unwrap();
        assert_eq!(expected.val, floored.val);
        assert_eq!(expected.val, floored_twice.val);
    }

    #[test]
    fn test_ceiling() {
        let expected = PreciseNumber::new(2).unwrap();
        let mut slightly_below = expected;
        slightly_below.val -= InnerUint::from(1);
        let ceiled = slightly_below.ceil().unwrap();
        let ceiled_twice = ceiled.ceil().unwrap();
        assert_eq!(expected.val, ceiled.val);
        assert_eq!(expected.val, ceiled_twice.val);
    }

    proptest! {
        #[test]
        fn test_square_root(a in 0..u128::MAX) {
            let a = PreciseNumber { val: InnerUint::from(a) };
            check_square_root(&a);
        }
    }
}