fixed_bigint/fixeduint/
roots_impl.rs1use crate::fixeduint::FixedUInt;
2use crate::machineword::MachineWord;
3use num_integer::Roots;
4use num_traits::{FromPrimitive, One, PrimInt, Zero};
5
6impl<T: MachineWord, const N: usize> Roots for FixedUInt<T, N> {
7 fn nth_root(&self, n: u32) -> Self {
8 if n == 0 {
9 panic!("nth_root: n must be non-zero");
10 }
11
12 if self.is_zero() {
13 return Self::zero();
14 }
15
16 if self.is_one() || n == 1 {
17 return *self;
18 }
19
20 let bit_len = self.bit_length();
21 if n > bit_len {
22 return Self::one();
23 }
24
25 let initial_exp = bit_len.div_ceil(n).max(1);
27 let mut x = Self::one() << (initial_exp as usize);
28
29 let n_val = Self::from_u32(n).expect("n too large for FixedUInt");
31 let n_minus_1 = Self::from_u32(n - 1).expect("n too large for FixedUInt");
32
33 loop {
35 let x_pow_n_minus_1 = x.pow(n - 1);
36
37 if x_pow_n_minus_1.is_zero() {
38 break;
39 }
40
41 let quotient = *self / x_pow_n_minus_1;
42
43 let numerator = x * n_minus_1 + quotient;
44 let x_new = numerator / n_val;
45
46 if x_new >= x {
47 break;
48 }
49
50 x = x_new;
51 }
52
53 while x.pow(n) > *self {
55 x -= Self::one();
56 }
57
58 let mut x_plus_one = x + Self::one();
59 while x_plus_one.pow(n) <= *self {
60 x += Self::one();
61 x_plus_one = x + Self::one();
62 }
63
64 x
65 }
66}
67
68#[cfg(test)]
69mod tests {
70 use super::*;
71 use num_integer::Roots;
72 use num_traits::{One, PrimInt};
73
74 #[test]
75 fn test_sqrt_basic() {
76 type TestInt = FixedUInt<u32, 2>;
77
78 assert_eq!(TestInt::from(0u8).sqrt(), TestInt::from(0u8));
79 assert_eq!(TestInt::from(1u8).sqrt(), TestInt::from(1u8));
80 assert_eq!(TestInt::from(4u8).sqrt(), TestInt::from(2u8));
81 assert_eq!(TestInt::from(9u8).sqrt(), TestInt::from(3u8));
82 assert_eq!(TestInt::from(16u8).sqrt(), TestInt::from(4u8));
83 assert_eq!(TestInt::from(25u8).sqrt(), TestInt::from(5u8));
84
85 assert_eq!(TestInt::from(2u8).sqrt(), TestInt::from(1u8));
87 assert_eq!(TestInt::from(3u8).sqrt(), TestInt::from(1u8));
88 assert_eq!(TestInt::from(8u8).sqrt(), TestInt::from(2u8));
89 assert_eq!(TestInt::from(15u8).sqrt(), TestInt::from(3u8));
90 assert_eq!(TestInt::from(24u8).sqrt(), TestInt::from(4u8));
91 }
92
93 #[test]
94 fn test_cbrt_basic() {
95 type TestInt = FixedUInt<u32, 2>;
96
97 assert_eq!(TestInt::from(0u8).cbrt(), TestInt::from(0u8));
98 assert_eq!(TestInt::from(1u8).cbrt(), TestInt::from(1u8));
99 assert_eq!(TestInt::from(8u8).cbrt(), TestInt::from(2u8));
100 assert_eq!(TestInt::from(27u8).cbrt(), TestInt::from(3u8));
101 assert_eq!(TestInt::from(64u8).cbrt(), TestInt::from(4u8));
102 assert_eq!(TestInt::from(125u8).cbrt(), TestInt::from(5u8));
103
104 assert_eq!(TestInt::from(2u8).cbrt(), TestInt::from(1u8));
106 assert_eq!(TestInt::from(7u8).cbrt(), TestInt::from(1u8));
107 assert_eq!(TestInt::from(26u8).cbrt(), TestInt::from(2u8));
108 assert_eq!(TestInt::from(63u8).cbrt(), TestInt::from(3u8));
109 }
110
111 #[test]
112 fn test_nth_root_basic() {
113 type TestInt = FixedUInt<u32, 2>;
114
115 assert_eq!(TestInt::from(16u8).nth_root(4), TestInt::from(2u8));
117 assert_eq!(TestInt::from(81u8).nth_root(4), TestInt::from(3u8));
118 assert_eq!(TestInt::from(15u8).nth_root(4), TestInt::from(1u8));
119 assert_eq!(TestInt::from(80u8).nth_root(4), TestInt::from(2u8));
120
121 assert_eq!(TestInt::from(32u8).nth_root(5), TestInt::from(2u8));
123 assert_eq!(TestInt::from(243u8).nth_root(5), TestInt::from(3u8));
124 assert_eq!(TestInt::from(31u8).nth_root(5), TestInt::from(1u8));
125
126 assert_eq!(TestInt::from(42u8).nth_root(1), TestInt::from(42u8));
128 assert_eq!(TestInt::from(123u8).nth_root(1), TestInt::from(123u8));
129 }
130
131 #[test]
132 fn test_nth_root_edge_cases() {
133 type TestInt = FixedUInt<u32, 2>;
134
135 assert_eq!(TestInt::from(0u8).nth_root(2), TestInt::from(0u8));
137 assert_eq!(TestInt::from(1u8).nth_root(2), TestInt::from(1u8));
138 assert_eq!(TestInt::from(0u8).nth_root(10), TestInt::from(0u8));
139 assert_eq!(TestInt::from(1u8).nth_root(10), TestInt::from(1u8));
140
141 assert_eq!(TestInt::from(2u8).nth_root(100), TestInt::from(1u8));
143 assert_eq!(TestInt::from(1000u16).nth_root(50), TestInt::from(1u8));
144 }
145
146 #[test]
147 #[should_panic(expected = "nth_root: n must be non-zero")]
148 fn test_nth_root_zero_n() {
149 let x = FixedUInt::<u32, 2>::from(16u8);
150 x.nth_root(0);
151 }
152
153 #[test]
154 fn test_root_correctness() {
155 type TestInt = FixedUInt<u32, 2>;
156
157 for x in 1..=100u16 {
159 let x_int = TestInt::from(x);
160
161 let sqrt_x = x_int.sqrt();
163 assert!(sqrt_x.pow(2) <= x_int);
164 assert!((sqrt_x + TestInt::one()).pow(2) > x_int);
165
166 let cbrt_x = x_int.cbrt();
168 assert!(cbrt_x.pow(3) <= x_int);
169 assert!((cbrt_x + TestInt::one()).pow(3) > x_int);
170
171 let root4_x = x_int.nth_root(4);
173 assert!(root4_x.pow(4) <= x_int);
174 assert!((root4_x + TestInt::one()).pow(4) > x_int);
175 }
176 }
177}