balancer_maths_rust/common/
oz_math.rs

1use num_bigint::BigInt;
2
3/// Computes the integer square root of a number using Newton's method
4/// Ported from OpenZeppelin's Solidity library to Rust
5/// @param a The input number (must be a non-negative integer)
6/// @returns The integer square root of a
7pub fn sqrt(a: &BigInt) -> BigInt {
8    // Handle edge cases when a is 0 or 1
9    if a <= &BigInt::from(1u64) {
10        return a.clone();
11    }
12
13    // Find an initial approximation using bit manipulation
14    // This approximation is close to 2^(log2(a)/2)
15    let mut aa = a.clone();
16    let mut xn = BigInt::from(1u64);
17
18    // Check if aa >= 2^128
19    let two_128 = BigInt::from(1u128) << 128;
20    if aa >= two_128 {
21        aa >>= 128;
22        xn <<= 64;
23    }
24
25    // Check if aa >= 2^64
26    let two_64 = BigInt::from(1u64) << 64;
27    if aa >= two_64 {
28        aa >>= 64;
29        xn <<= 32;
30    }
31
32    // Check if aa >= 2^32
33    let two_32 = BigInt::from(1u32) << 32;
34    if aa >= two_32 {
35        aa >>= 32;
36        xn <<= 16;
37    }
38
39    // Check if aa >= 2^16
40    let two_16 = BigInt::from(1u16) << 16;
41    if aa >= two_16 {
42        aa >>= 16;
43        xn <<= 8;
44    }
45
46    // Check if aa >= 2^8
47    let two_8 = BigInt::from(1u8) << 8;
48    if aa >= two_8 {
49        aa >>= 8;
50        xn <<= 4;
51    }
52
53    // Check if aa >= 2^4
54    let two_4 = BigInt::from(1u8) << 4;
55    if aa >= two_4 {
56        aa >>= 4;
57        xn <<= 2;
58    }
59
60    // Check if aa >= 2^2
61    let two_2 = BigInt::from(1u8) << 2;
62    if aa >= two_2 {
63        xn <<= 1;
64    }
65
66    // Refine the initial approximation
67    xn = (&xn * 3) >> 1;
68
69    // Apply Newton's method iterations
70    // Each iteration approximately doubles the number of correct bits
71    xn = (&xn + &(a / &xn)) >> 1;
72    xn = (&xn + &(a / &xn)) >> 1;
73    xn = (&xn + &(a / &xn)) >> 1;
74    xn = (&xn + &(a / &xn)) >> 1;
75    xn = (&xn + &(a / &xn)) >> 1;
76
77    // Final adjustment: if xn > sqrt(a), decrement by 1
78    if xn > (a / &xn) {
79        xn - 1
80    } else {
81        xn
82    }
83}