cosmwasm_std/math/
isqrt.rs1use core::{cmp, ops};
2
3use crate::{Uint128, Uint256, Uint512, Uint64};
4
/// Integer square root, i.e. `floor(sqrt(self))`.
///
/// See <https://en.wikipedia.org/wiki/Integer_square_root>.
pub trait Isqrt {
    /// Returns the integer square root of `self` as a new value; `self` is not modified.
    #[must_use = "this returns the result of the operation, without modifying the original"]
    fn isqrt(self) -> Self;
}
13impl<I> Isqrt for I
14where
15 I: Unsigned
16 + ops::Add<I, Output = I>
17 + ops::Div<I, Output = I>
18 + ops::Shl<u32, Output = I>
19 + ops::Shr<u32, Output = I>
20 + cmp::PartialOrd
21 + Copy,
22{
23 fn isqrt(self) -> Self {
26 if self <= Self::ONE {
28 return self;
29 }
30
31 let mut x0 = Self::ONE << ((self.log_2() / 2) + 1);
32
33 if x0 > Self::ZERO {
34 let mut x1 = (x0 + self / x0) >> 1;
35
36 while x1 < x0 {
37 x0 = x1;
38 x1 = (x0 + self / x0) >> 1;
39 }
40
41 return x0;
42 }
43 self
44 }
45}
46
/// Minimal view of an unsigned integer type, as required by the generic [`Isqrt`]
/// implementation in this module.
pub trait Unsigned {
    /// The additive identity (`0`).
    const ZERO: Self;
    /// The multiplicative identity (`1`).
    const ONE: Self;

    /// Base-2 logarithm of `self`, rounded down.
    ///
    /// The `isqrt` implementation in this module only calls this for values greater than one.
    /// NOTE(review): behavior for zero is left to the implementor (the macro-generated impls
    /// delegate to `ilog2`, which for the primitive types panics on zero).
    fn log_2(self) -> u32;
}
/// Implements [`Unsigned`] for the type `$t`, using `$z` and `$o` as its zero and one
/// constants (they must be usable in a `const` context) and delegating `log_2` to the
/// type's `ilog2` method.
macro_rules! impl_unsigned {
    ($t:ty, $z:expr, $o:expr) => {
        impl Unsigned for $t {
            const ZERO: Self = $z;
            const ONE: Self = $o;

            fn log_2(self) -> u32 {
                self.ilog2()
            }
        }
    };
}
67impl_unsigned!(u8, 0, 1);
68impl_unsigned!(u16, 0, 1);
69impl_unsigned!(u32, 0, 1);
70impl_unsigned!(u64, 0, 1);
71impl_unsigned!(u128, 0, 1);
72impl_unsigned!(usize, 0, 1);
73impl_unsigned!(Uint64, Self::zero(), Self::one());
74impl_unsigned!(Uint128, Self::zero(), Self::one());
75impl_unsigned!(Uint256, Self::zero(), Self::one());
76impl_unsigned!(Uint512, Self::zero(), Self::one());
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `root` is the integer square root of `n`, i.e.
    /// `root^2 <= n < (root + 1)^2`, computed in `u128` so the check itself cannot overflow.
    fn assert_is_floor_sqrt(n: u64, root: u64) {
        let (n, root) = (u128::from(n), u128::from(root));
        assert!(root * root <= n, "{root}^2 > {n}");
        assert!(n < (root + 1) * (root + 1), "({root} + 1)^2 <= {n}");
    }

    // Newer Rust toolchains (1.84+) provide an inherent `isqrt` on primitive integers, which
    // method-call syntax would pick over the trait method; call the trait explicitly instead.
    #[test]
    fn isqrt_primitives() {
        assert_eq!(Isqrt::isqrt(0u8), 0);
        assert_eq!(Isqrt::isqrt(1u8), 1);
        assert_eq!(Isqrt::isqrt(24u8), 4);
        assert_eq!(Isqrt::isqrt(25u8), 5);
        assert_eq!(Isqrt::isqrt(26u8), 5);
        assert_eq!(Isqrt::isqrt(36u8), 6);

        assert_eq!(Isqrt::isqrt(26u16), 5);
        assert_eq!(Isqrt::isqrt(26u32), 5);
        assert_eq!(Isqrt::isqrt(26u64), 5);
        assert_eq!(Isqrt::isqrt(26u128), 5);
        assert_eq!(Isqrt::isqrt(26usize), 5);
    }

    // The largest inputs exercise the widest initial shift and the largest `x0 + self / x0`.
    #[test]
    fn isqrt_primitives_max() {
        assert_eq!(Isqrt::isqrt(u8::MAX), 15);
        assert_eq!(Isqrt::isqrt(u16::MAX), 255);
        assert_eq!(Isqrt::isqrt(u32::MAX), 65_535);
        assert_eq!(Isqrt::isqrt(u64::MAX), u64::from(u32::MAX));
        assert_eq!(Isqrt::isqrt(u128::MAX), u128::from(u64::MAX));
    }

    #[test]
    fn isqrt_u16_exhaustive() {
        for n in 0..=u16::MAX {
            assert_is_floor_sqrt(u64::from(n), u64::from(Isqrt::isqrt(n)));
        }
    }

    #[test]
    fn isqrt_u64_around_perfect_squares() {
        for root in [2u64, 3, 1_000, 65_535, 65_536, u64::from(u32::MAX)] {
            let square = root * root;
            assert_eq!(Isqrt::isqrt(square), root);
            assert_eq!(Isqrt::isqrt(square - 1), root - 1);
            assert_eq!(Isqrt::isqrt(square + 1), root);
            assert_is_floor_sqrt(square + root, Isqrt::isqrt(square + root));
        }
    }

    #[test]
    fn isqrt_uint64() {
        assert_eq!(Uint64::new(24).isqrt(), Uint64::new(4));
        assert_eq!(Uint64::MAX.isqrt(), Uint64::new(u64::from(u32::MAX)));
    }

    #[test]
    fn isqrt_uint128() {
        assert_eq!(Uint128::new(24).isqrt(), Uint128::new(4));
        assert_eq!(Uint128::MAX.isqrt(), Uint128::new(u128::from(u64::MAX)));
    }

    #[test]
    fn isqrt_uint256() {
        assert_eq!(Uint256::from(24u32).isqrt(), Uint256::from(4u32));
        assert_eq!(
            (Uint256::from(u128::MAX) * Uint256::from(u128::MAX)).isqrt(),
            Uint256::try_from("340282366920938463463374607431768211455").unwrap()
        );
    }

    #[test]
    fn isqrt_uint512() {
        assert_eq!(Uint512::from(24u32).isqrt(), Uint512::from(4u32));
        assert_eq!(
            (Uint512::from(Uint256::MAX) * Uint512::from(Uint256::MAX)).isqrt(),
            Uint512::try_from(
                "115792089237316195423570985008687907853269984665640564039457584007913129639935"
            )
            .unwrap()
        );
    }
}