Skip to main content

crypto_bigint/uint/
root.rs

1//! Support for nth root calculation for [`Uint`].
2
3use core::num::NonZeroU32;
4
5use crate::{Limb, NonZero, Reciprocal, U64, Uint};
6
7impl<const LIMBS: usize> Uint<LIMBS> {
8    /// Computes `floor(self^(1/exp))`.
9    ///
10    /// Callers can check if `self` is an exact power of `exp` by exponentiating the result.
11    ///
12    /// This method is variable time in `self` and in the exponent.
13    #[must_use]
14    pub const fn floor_root_vartime(&self, exp: NonZeroU32) -> Self {
15        if self.is_zero_vartime() {
16            Self::ZERO
17        } else {
18            NonZero(*self).floor_root_vartime(exp).get_copy()
19        }
20    }
21
22    /// Compute the root `self^(1/exp)` returning an [`Option`] which `is_some`
23    /// only if the root is exact.
24    ///
25    /// This method is variable time in `self` and in the exponent.
26    pub fn checked_root_vartime(&self, exp: NonZeroU32) -> Option<Self> {
27        if self.is_zero_vartime() {
28            Some(Self::ZERO)
29        } else {
30            NonZero(*self).checked_root_vartime(exp).map(NonZero::get)
31        }
32    }
33}
34
35impl<const LIMBS: usize> NonZero<Uint<LIMBS>> {
36    /// Computes `floor(self^(1/exp))`.
37    ///
38    /// Callers can check if `self` is an exact power of `exp` by exponentiating the result.
39    ///
40    /// This method is variable time in self and in the exponent.
41    #[must_use]
42    pub const fn floor_root_vartime(&self, exp: NonZeroU32) -> Self {
43        // Uses Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14.
44
45        if exp.get() == 1 {
46            return *self;
47        }
48        let exp_m1 = exp.get() - 1;
49        let exp_m1_limb = Limb::from_u32(exp_m1);
50        let exp_recip = Reciprocal::new(NonZero::<Limb>::from_u32(exp));
51
52        let rt_bits = self.0.bits().div_ceil(exp.get());
53        // The initial guess: `x_0 = 2^ceil(b/exp)`, where `exp^(b-1) <= self < exp^b`.
54        // Will not overflow since `b <= BITS`.
55        let mut x = Uint::<LIMBS>::ZERO.set_bit_vartime(rt_bits, true);
56        // Compute `self.0 / x_0^(exp-1)` by shifting.
57        let mut q = self.0.shr(rt_bits * exp_m1);
58
59        loop {
60            // Calculate `x_{i+1} = floor((x_i*(exp-1) + self / x_i^(1/(exp-1))) / exp)`, leaving `x` unmodified
61            // if it would increase.
62            let x2 = x
63                .wrapping_mul_limb(exp_m1_limb)
64                .wrapping_add(&q)
65                .div_rem_limb_with_reciprocal(&exp_recip)
66                .0;
67
68            // Terminate if `x_{i+1}` >= `x`.
69            if x2.cmp_vartime(&x).is_ge() {
70                return x.to_nz().expect_copied("ensured non-zero");
71            }
72            x = x2;
73
74            (q, _) = self.0.div_rem_vartime(
75                x.wrapping_pow_vartime(&U64::from_u32(exp_m1))
76                    .to_nz()
77                    .expect_ref("ensured non-zero"),
78            );
79        }
80    }
81
82    /// Compute the root `self^(1/exp)` returning an [`Option`] which `is_some`
83    /// only if the root is exact.
84    ///
85    /// This method is variable time in `self` and in the exponent.
86    #[must_use]
87    pub fn checked_root_vartime(&self, exp: NonZeroU32) -> Option<Self> {
88        let r = self.floor_root_vartime(exp);
89        let s = r.wrapping_pow_vartime(&U64::from_u32(exp.get()));
90        if self.cmp_vartime(&s).is_eq() {
91            Some(r)
92        } else {
93            None
94        }
95    }
96}
97
98#[cfg(test)]
99mod tests {
100    use crate::U256;
101    use core::num::NonZeroU32;
102
103    #[cfg(feature = "rand_core")]
104    use {
105        crate::{Limb, Random},
106        chacha20::ChaCha8Rng,
107        rand_core::SeedableRng,
108    };
109
110    #[test]
111    fn floor_root_vartime_expected() {
112        let three = NonZeroU32::new(3).unwrap();
113        assert_eq!(U256::from(0u8).floor_root_vartime(three), U256::from(0u8));
114        assert!(U256::from(0u8).checked_root_vartime(three).is_some());
115        assert_eq!(U256::from(1u8).floor_root_vartime(three), U256::from(1u8));
116        assert!(U256::from(1u8).checked_root_vartime(three).is_some());
117        assert_eq!(U256::from(2u8).floor_root_vartime(three), U256::from(1u8));
118        assert!(U256::from(2u8).checked_root_vartime(three).is_none());
119        assert_eq!(U256::from(8u8).floor_root_vartime(three), U256::from(2u8));
120        assert!(U256::from(8u8).checked_root_vartime(three).is_some());
121        assert_eq!(U256::from(9u8).floor_root_vartime(three), U256::from(2u8));
122        assert!(U256::from(9u8).checked_root_vartime(three).is_none());
123    }
124
125    #[cfg(feature = "rand_core")]
126    #[test]
127    fn fuzz() {
128        use crate::U64;
129        use core::num::NonZeroU32;
130
131        let mut rng: ChaCha8Rng = ChaCha8Rng::from_seed([7u8; 32]);
132
133        for _ in 0..50 {
134            let s = U256::random_from_rng(&mut rng);
135            let Some(s) = s.to_nz().into_option() else {
136                continue;
137            };
138            for exp in 1..10 {
139                let exp = NonZeroU32::new(exp).unwrap();
140                let exp_uint = U64::from_u32(exp.get());
141                let root = s.floor_root_vartime(exp);
142
143                // root is correct if rt^exp <= s and (rt+1)^exp > s
144                let s2 = root
145                    .checked_pow_vartime(&exp_uint)
146                    .expect("overflow, {s} exp={exp}, root={rt}");
147                assert!(s2 <= s.get(), "overflow, {s} exp={exp}, root={root}");
148                let rt_p1 = root.wrapping_add_limb(Limb::ONE);
149                let s3 = rt_p1.checked_pow_vartime(&exp_uint).into_option();
150                assert!(
151                    s3.is_none_or(|s3| s3 > s2),
152                    "underflow, {s} exp={exp}, root={root}"
153                );
154            }
155        }
156    }
157}