fixed_bigint/fixeduint/
roots_impl.rs

1use crate::fixeduint::FixedUInt;
2use crate::machineword::MachineWord;
3use num_integer::Roots;
4use num_traits::{FromPrimitive, One, PrimInt, Zero};
5
6impl<T: MachineWord, const N: usize> Roots for FixedUInt<T, N> {
7    fn nth_root(&self, n: u32) -> Self {
8        if n == 0 {
9            panic!("nth_root: n must be non-zero");
10        }
11
12        if self.is_zero() {
13            return Self::zero();
14        }
15
16        if self.is_one() || n == 1 {
17            return *self;
18        }
19
20        let bit_len = self.bit_length();
21        if n > bit_len {
22            return Self::one();
23        }
24
25        // Initial guess: use ceiling(bit_len / n) for overestimate
26        let initial_exp = bit_len.div_ceil(n).max(1);
27        let mut x = Self::one() << (initial_exp as usize);
28
29        // Constants using FromPrimitive
30        let n_val = Self::from_u32(n).expect("n too large for FixedUInt");
31        let n_minus_1 = Self::from_u32(n - 1).expect("n too large for FixedUInt");
32
33        // Newton's method iteration
34        loop {
35            let x_pow_n_minus_1 = x.pow(n - 1);
36
37            if x_pow_n_minus_1.is_zero() {
38                break;
39            }
40
41            let quotient = *self / x_pow_n_minus_1;
42
43            let numerator = x * n_minus_1 + quotient;
44            let x_new = numerator / n_val;
45
46            if x_new >= x {
47                break;
48            }
49
50            x = x_new;
51        }
52
53        // Final adjustment to ensure r^n <= self < (r+1)^n
54        while x.pow(n) > *self {
55            x -= Self::one();
56        }
57
58        let mut x_plus_one = x + Self::one();
59        while x_plus_one.pow(n) <= *self {
60            x += Self::one();
61            x_plus_one = x + Self::one();
62        }
63
64        x
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use num_integer::Roots;
    use num_traits::{One, PrimInt};

    /// 64-bit test integer made of two 32-bit words, shared by every test.
    type TestInt = FixedUInt<u32, 2>;

    /// Builds a `TestInt` from a small literal.
    fn int(value: u8) -> TestInt {
        TestInt::from(value)
    }

    /// Asserts `root(input) == expected` for each `(input, expected)` pair, in order.
    fn check_cases<F>(root: F, cases: &[(u8, u8)])
    where
        F: Fn(&TestInt) -> TestInt,
    {
        for &(input, expected) in cases {
            assert_eq!(root(&int(input)), int(expected));
        }
    }

    /// Asserts that `r` is the floor `n`-th root of `x`: `r^n <= x < (r + 1)^n`.
    fn assert_floor_root(x: TestInt, r: TestInt, n: u32) {
        assert!(r.pow(n) <= x);
        assert!((r + TestInt::one()).pow(n) > x);
    }

    #[test]
    fn test_sqrt_basic() {
        check_cases(
            |v| v.sqrt(),
            &[
                // Perfect squares.
                (0, 0),
                (1, 1),
                (4, 2),
                (9, 3),
                (16, 4),
                (25, 5),
                // Non-perfect squares round down.
                (2, 1),
                (3, 1),
                (8, 2),
                (15, 3),
                (24, 4),
            ],
        );
    }

    #[test]
    fn test_cbrt_basic() {
        check_cases(
            |v| v.cbrt(),
            &[
                // Perfect cubes.
                (0, 0),
                (1, 1),
                (8, 2),
                (27, 3),
                (64, 4),
                (125, 5),
                // Non-perfect cubes round down.
                (2, 1),
                (7, 1),
                (26, 2),
                (63, 3),
            ],
        );
    }

    #[test]
    fn test_nth_root_basic() {
        // (input, degree, expected root)
        let cases: [(u8, u32, u8); 9] = [
            // 4th roots.
            (16, 4, 2),
            (81, 4, 3),
            (15, 4, 1),
            (80, 4, 2),
            // 5th roots.
            (32, 5, 2),
            (243, 5, 3),
            (31, 5, 1),
            // Degree 1 is the identity.
            (42, 1, 42),
            (123, 1, 123),
        ];
        for &(input, degree, expected) in cases.iter() {
            assert_eq!(int(input).nth_root(degree), int(expected));
        }
    }

    #[test]
    fn test_nth_root_edge_cases() {
        // 0 and 1 are their own roots for any degree.
        for degree in [2u32, 10] {
            for value in [0u8, 1] {
                assert_eq!(int(value).nth_root(degree), int(value));
            }
        }

        // A degree larger than the bit length yields 1 for any value > 1.
        assert_eq!(int(2).nth_root(100), int(1));
        assert_eq!(TestInt::from(1000u16).nth_root(50), int(1));
    }

    #[test]
    #[should_panic(expected = "nth_root: n must be non-zero")]
    fn test_nth_root_zero_n() {
        let _ = int(16).nth_root(0);
    }

    #[test]
    fn test_root_correctness() {
        // Every computed root must satisfy r^n <= x < (r + 1)^n.
        for x in (1..=100u16).map(TestInt::from) {
            assert_floor_root(x, x.sqrt(), 2);
            assert_floor_root(x, x.cbrt(), 3);
            assert_floor_root(x, x.nth_root(4), 4);
        }
    }
}