Skip to main content

fixed_bigint/fixeduint/
roots_impl.rs

1use crate::fixeduint::FixedUInt;
2use crate::machineword::MachineWord;
3use crate::personality::Nct;
4use num_integer::Roots;
5use num_traits::{FromPrimitive, One, PrimInt, Zero};
6
7impl<T: MachineWord, const N: usize> Roots for FixedUInt<T, N, Nct> {
8    fn nth_root(&self, n: u32) -> Self {
9        if n == 0 {
10            panic!("nth_root: n must be non-zero");
11        }
12
13        if self.is_zero() {
14            return Self::zero();
15        }
16
17        if self.is_one() || n == 1 {
18            return *self;
19        }
20
21        let bit_len = self.bit_length();
22        if n > bit_len {
23            return Self::one();
24        }
25
26        // Initial guess: use ceiling(bit_len / n) for overestimate
27        let initial_exp = bit_len.div_ceil(n).max(1);
28        let mut x = Self::one() << (initial_exp as usize);
29
30        // Constants using FromPrimitive
31        let n_val = Self::from_u32(n).expect("n too large for FixedUInt");
32        let n_minus_1 = Self::from_u32(n - 1).expect("n too large for FixedUInt");
33
34        // Newton's method iteration
35        loop {
36            let x_pow_n_minus_1 = x.pow(n - 1);
37
38            if x_pow_n_minus_1.is_zero() {
39                break;
40            }
41
42            let quotient = *self / x_pow_n_minus_1;
43
44            let numerator = x * n_minus_1 + quotient;
45            let x_new = numerator / n_val;
46
47            if x_new >= x {
48                break;
49            }
50
51            x = x_new;
52        }
53
54        // Final adjustment to ensure r^n <= self < (r+1)^n
55        while x.pow(n) > *self {
56            x -= Self::one();
57        }
58
59        let mut x_plus_one = x + Self::one();
60        while x_plus_one.pow(n) <= *self {
61            x += Self::one();
62            x_plus_one = x + Self::one();
63        }
64
65        x
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use num_integer::Roots;
    use num_traits::{One, PrimInt};

    /// 64-bit integer (two 32-bit limbs) exercised by every test below.
    type TestInt = FixedUInt<u32, 2>;

    /// Applies `root` to each input and checks it against the paired expected value.
    fn check_cases(cases: &[(u8, u8)], root: impl Fn(TestInt) -> TestInt) {
        for &(input, expected) in cases {
            assert_eq!(root(TestInt::from(input)), TestInt::from(expected));
        }
    }

    #[test]
    fn test_sqrt_basic() {
        // Perfect squares first, then non-perfect squares (rounded down).
        check_cases(
            &[
                (0, 0),
                (1, 1),
                (4, 2),
                (9, 3),
                (16, 4),
                (25, 5),
                (2, 1),
                (3, 1),
                (8, 2),
                (15, 3),
                (24, 4),
            ],
            |x| x.sqrt(),
        );
    }

    #[test]
    fn test_cbrt_basic() {
        // Perfect cubes first, then non-perfect cubes (rounded down).
        check_cases(
            &[
                (0, 0),
                (1, 1),
                (8, 2),
                (27, 3),
                (64, 4),
                (125, 5),
                (2, 1),
                (7, 1),
                (26, 2),
                (63, 3),
            ],
            |x| x.cbrt(),
        );
    }

    #[test]
    fn test_nth_root_basic() {
        check_cases(&[(16, 2), (81, 3), (15, 1), (80, 2)], |x| x.nth_root(4));
        check_cases(&[(32, 2), (243, 3), (31, 1)], |x| x.nth_root(5));
        // The first root is the identity.
        check_cases(&[(42, 42), (123, 123)], |x| x.nth_root(1));
    }

    #[test]
    fn test_nth_root_edge_cases() {
        // 0 and 1 are fixed points of every root.
        for n in [2u32, 10] {
            assert_eq!(TestInt::from(0u8).nth_root(n), TestInt::from(0u8));
            assert_eq!(TestInt::from(1u8).nth_root(n), TestInt::from(1u8));
        }

        // An exponent wider than the value collapses the root to 1.
        assert_eq!(TestInt::from(2u8).nth_root(100), TestInt::from(1u8));
        assert_eq!(TestInt::from(1000u16).nth_root(50), TestInt::from(1u8));
    }

    #[test]
    #[should_panic(expected = "nth_root: n must be non-zero")]
    fn test_nth_root_zero_n() {
        let _ = TestInt::from(16u8).nth_root(0);
    }

    #[test]
    fn test_root_correctness() {
        // The root r of x for exponent k must bracket x: r^k <= x < (r+1)^k.
        let assert_bracketed = |x: TestInt, r: TestInt, k: u32| {
            assert!(r.pow(k) <= x);
            assert!((r + TestInt::one()).pow(k) > x);
        };

        for value in 1..=100u16 {
            let x = TestInt::from(value);
            assert_bracketed(x, x.sqrt(), 2);
            assert_bracketed(x, x.cbrt(), 3);
            assert_bracketed(x, x.nth_root(4), 4);
        }
    }
}