fixed_bigint/fixeduint/
roots_impl.rs1use crate::fixeduint::FixedUInt;
2use crate::machineword::MachineWord;
3use crate::personality::Nct;
4use num_integer::Roots;
5use num_traits::{FromPrimitive, One, PrimInt, Zero};
6
7impl<T: MachineWord, const N: usize> Roots for FixedUInt<T, N, Nct> {
8 fn nth_root(&self, n: u32) -> Self {
9 if n == 0 {
10 panic!("nth_root: n must be non-zero");
11 }
12
13 if self.is_zero() {
14 return Self::zero();
15 }
16
17 if self.is_one() || n == 1 {
18 return *self;
19 }
20
21 let bit_len = self.bit_length();
22 if n > bit_len {
23 return Self::one();
24 }
25
26 let initial_exp = bit_len.div_ceil(n).max(1);
28 let mut x = Self::one() << (initial_exp as usize);
29
30 let n_val = Self::from_u32(n).expect("n too large for FixedUInt");
32 let n_minus_1 = Self::from_u32(n - 1).expect("n too large for FixedUInt");
33
34 loop {
36 let x_pow_n_minus_1 = x.pow(n - 1);
37
38 if x_pow_n_minus_1.is_zero() {
39 break;
40 }
41
42 let quotient = *self / x_pow_n_minus_1;
43
44 let numerator = x * n_minus_1 + quotient;
45 let x_new = numerator / n_val;
46
47 if x_new >= x {
48 break;
49 }
50
51 x = x_new;
52 }
53
54 while x.pow(n) > *self {
56 x -= Self::one();
57 }
58
59 let mut x_plus_one = x + Self::one();
60 while x_plus_one.pow(n) <= *self {
61 x += Self::one();
62 x_plus_one = x + Self::one();
63 }
64
65 x
66 }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use num_integer::Roots;
    use num_traits::{One, PrimInt};

    /// 64-bit fixed-width integer used by every test in this module.
    type TestInt = FixedUInt<u32, 2>;

    /// Shorthand for building a `TestInt` from a small literal.
    fn int(value: u16) -> TestInt {
        TestInt::from(value)
    }

    /// Asserts that `root` is the floor `n`-th root of `value`, i.e.
    /// `root^n <= value < (root + 1)^n`.
    fn assert_floor_root(value: TestInt, root: TestInt, n: u32) {
        assert!(root.pow(n) <= value);
        assert!((root + TestInt::one()).pow(n) > value);
    }

    #[test]
    fn test_sqrt_basic() {
        // (radicand, expected floor square root): perfect squares first,
        // then values that are not perfect squares.
        let cases = [
            (0, 0),
            (1, 1),
            (4, 2),
            (9, 3),
            (16, 4),
            (25, 5),
            (2, 1),
            (3, 1),
            (8, 2),
            (15, 3),
            (24, 4),
        ];
        for &(radicand, root) in &cases {
            assert_eq!(int(radicand).sqrt(), int(root));
        }
    }

    #[test]
    fn test_cbrt_basic() {
        // (radicand, expected floor cube root): perfect cubes first,
        // then values that are not perfect cubes.
        let cases = [
            (0, 0),
            (1, 1),
            (8, 2),
            (27, 3),
            (64, 4),
            (125, 5),
            (2, 1),
            (7, 1),
            (26, 2),
            (63, 3),
        ];
        for &(radicand, root) in &cases {
            assert_eq!(int(radicand).cbrt(), int(root));
        }
    }

    #[test]
    fn test_nth_root_basic() {
        // (radicand, n, expected floor n-th root)
        let cases = [
            (16, 4, 2),
            (81, 4, 3),
            (15, 4, 1),
            (80, 4, 2),
            (32, 5, 2),
            (243, 5, 3),
            (31, 5, 1),
            (42, 1, 42),
            (123, 1, 123),
        ];
        for &(radicand, n, root) in &cases {
            assert_eq!(int(radicand).nth_root(n), int(root));
        }
    }

    #[test]
    fn test_nth_root_edge_cases() {
        // Zero and one are fixed points; huge exponents collapse to one.
        let cases = [
            (0, 2, 0),
            (1, 2, 1),
            (0, 10, 0),
            (1, 10, 1),
            (2, 100, 1),
            (1000, 50, 1),
        ];
        for &(radicand, n, root) in &cases {
            assert_eq!(int(radicand).nth_root(n), int(root));
        }
    }

    #[test]
    #[should_panic(expected = "nth_root: n must be non-zero")]
    fn test_nth_root_zero_n() {
        let value = FixedUInt::<u32, 2>::from(16u8);
        value.nth_root(0);
    }

    #[test]
    fn test_root_correctness() {
        for value in (1..=100u16).map(int) {
            assert_floor_root(value, value.sqrt(), 2);
            assert_floor_root(value, value.cbrt(), 3);
            assert_floor_root(value, value.nth_root(4), 4);
        }
    }
}