// dashu_int/modular/pow.rs

1use crate::{
2    arch::word::Word,
3    div_const::{ConstDoubleDivisor, ConstSingleDivisor},
4    primitive::{split_dword, WORD_BITS},
5    repr::TypedReprRef::*,
6    ubig::UBig,
7};
8
9use super::repr::{Reduced, ReducedDword, ReducedRepr, ReducedWord};
10use num_modular::Reducer;
11
12impl<'a> Reduced<'a> {
13    /// Exponentiation.
14    ///
15    /// If you want use a negative exponent, you can first use [inv()][Self::inv] to
16    /// convert the base to its inverse, and then call this method.
17    ///
18    /// # Examples
19    ///
20    /// ```
21    /// # use dashu_int::{fast_div::ConstDivisor, UBig};
22    /// // A Mersenne prime.
23    /// let p = UBig::from(2u8).pow(607) - UBig::ONE;
24    /// let ring = ConstDivisor::new(p.clone());
25    /// // Fermat's little theorem: a^(p-1) = 1 (mod p)
26    /// let a = ring.reduce(123);
27    /// assert_eq!(a.pow(&(p - UBig::ONE)), ring.reduce(1));
28    /// ```
29    #[inline]
30    pub fn pow(&self, exp: &UBig) -> Reduced<'a> {
31        match self.repr() {
32            ReducedRepr::Single(raw, ring) => {
33                Reduced::from_single(single::pow(ring, *raw, exp), ring)
34            }
35            ReducedRepr::Double(raw, ring) => {
36                Reduced::from_double(double::pow(ring, *raw, exp), ring)
37            }
38            ReducedRepr::Large(raw, ring) => Reduced::from_large(large::pow(ring, raw, exp), ring),
39        }
40    }
41}
42
43macro_rules! impl_mod_pow_for_primitive {
44    ($ns:ident, $ring:ty, $raw:ident) => {
45        mod $ns {
46            use super::*;
47
48            #[inline]
49            pub(super) fn pow_word(ring: &$ring, raw: $raw, exp: Word) -> $raw {
50                match exp {
51                    0 => <$raw>::one(ring),
52                    1 => raw, // no-op
53                    2 => $raw(ring.0.sqr(raw.0)),
54                    _ => {
55                        let bits = WORD_BITS - 1 - exp.leading_zeros();
56                        pow_helper(ring, raw, raw, exp, bits)
57                    }
58                }
59            }
60
61            /// lhs^2^bits * rhs^exp[..bits] (in the modulo ring)
62            #[inline]
63            fn pow_helper(ring: &$ring, lhs: $raw, rhs: $raw, exp: Word, mut bits: u32) -> $raw {
64                let mut res = lhs;
65                while bits > 0 {
66                    res.0 = ring.0.sqr(res.0);
67                    bits -= 1;
68                    if exp & (1 << bits) != 0 {
69                        res.0 = ring.0.mul(&res.0, &rhs.0);
70                    }
71                }
72                res
73            }
74
75            /// Exponentiation.
76            #[inline]
77            pub(super) fn pow(ring: &$ring, raw: $raw, exp: &UBig) -> $raw {
78                match exp.repr() {
79                    RefSmall(dword) => {
80                        let (lo, hi) = split_dword(dword);
81                        if hi == 0 {
82                            pow_word(ring, raw, lo)
83                        } else {
84                            let res = pow_word(ring, raw, hi);
85                            pow_helper(ring, res, raw, lo, WORD_BITS)
86                        }
87                    }
88                    RefLarge(words) => pow_nontrivial(ring, raw, words),
89                }
90            }
91
92            fn pow_nontrivial(ring: &$ring, raw: $raw, exp_words: &[Word]) -> $raw {
93                let mut n = exp_words.len() - 1;
94                let mut res = pow_word(ring, raw, exp_words[n]); // apply the top word
95                while n != 0 {
96                    n -= 1;
97                    res = pow_helper(ring, res, raw, exp_words[n], WORD_BITS);
98                }
99                res
100            }
101        }
102    };
103}
104impl_mod_pow_for_primitive!(single, ConstSingleDivisor, ReducedWord);
105impl_mod_pow_for_primitive!(double, ConstDoubleDivisor, ReducedDword);
106
107mod large {
108    use dashu_base::BitTest;
109
110    use super::{
111        super::mul::{mul_memory_requirement, mul_normalized, sqr_in_place},
112        *,
113    };
114    use crate::{
115        div_const::ConstLargeDivisor,
116        error::panic_allocate_too_much,
117        math,
118        memory::{self, MemoryAllocation},
119        modular::repr::ReducedLarge,
120        primitive::{double_word, split_dword, PrimitiveUnsigned, WORD_BITS_USIZE},
121    };
122
123    pub(super) fn pow(ring: &ConstLargeDivisor, raw: &ReducedLarge, exp: &UBig) -> ReducedLarge {
124        if exp.is_zero() {
125            ReducedLarge::one(ring)
126        } else if exp.is_one() {
127            raw.clone()
128        } else {
129            pow_nontrivial(ring, raw, exp)
130        }
131    }
132
133    fn pow_nontrivial(ring: &ConstLargeDivisor, raw: &ReducedLarge, exp: &UBig) -> ReducedLarge {
134        let n = ring.normalized_divisor.len();
135        let window_len = choose_pow_window_len(exp.bit_len());
136
137        // Precomputed table of small odd powers up to 2^window_len, starting from raw^3.
138        #[allow(clippy::redundant_closure)]
139        let table_words = ((1usize << (window_len - 1)) - 1)
140            .checked_mul(n)
141            .unwrap_or_else(|| panic_allocate_too_much());
142
143        let memory_requirement = memory::add_layout(
144            memory::array_layout::<Word>(table_words),
145            mul_memory_requirement(ring),
146        );
147        let mut allocation = MemoryAllocation::new(memory_requirement);
148        let mut memory = allocation.memory();
149        let (table, mut memory) = memory.allocate_slice_fill::<Word>(table_words, 0);
150
151        // val = raw^2
152        let mut val = raw.clone();
153        sqr_in_place(ring, &mut val, &mut memory);
154
155        // raw^(2*i+1) = raw^(2*i-1) * val
156        for i in 1..(1 << (window_len - 1)) {
157            let (prev, cur) = if i == 1 {
158                (raw.0.as_ref(), &mut table[0..n])
159            } else {
160                let (prev, cur) = table[(i - 2) * n..i * n].split_at_mut(n);
161                (&*prev, cur)
162            };
163            cur.copy_from_slice(mul_normalized(ring, prev, &val.0, &mut memory));
164        }
165
166        let exp_words = exp.as_words();
167        // We already have raw^2 in val.
168        // exp.bit_len() >= 2 because exp >= 2.
169        let mut bit = exp.bit_len() - 2;
170
171        loop {
172            // val = raw ^ exp[bit..] ignoring the lowest bit
173            let word_idx = bit / WORD_BITS_USIZE;
174            let bit_idx = (bit % WORD_BITS_USIZE) as u32;
175            let cur_word = exp_words[word_idx];
176            if cur_word & (1 << bit_idx) != 0 {
177                let next_word = if word_idx == 0 {
178                    0
179                } else {
180                    exp_words[word_idx - 1]
181                };
182                // Get a window of window_len bits, with top bit of 1.
183                let (mut window, _) = split_dword(
184                    double_word(next_word, cur_word) >> (bit_idx + 1 + WORD_BITS - window_len),
185                );
186                window &= math::ones_word(window_len);
187                // Shift right to make the window odd.
188                let num_bits = window_len - window.trailing_zeros();
189                window >>= window_len - num_bits;
190                // val := val^2^(num_bits-1)
191                for _ in 0..num_bits - 1 {
192                    sqr_in_place(ring, &mut val, &mut memory);
193                }
194                bit -= (num_bits as usize) - 1;
195                // Now val = raw ^ exp[bit..] ignoring the num_bits lowest bits.
196                // val = val * raw^window from precomputed table.
197                debug_assert!(window & 1 == 1);
198                let entry_idx = (window >> 1) as usize;
199                let entry = if entry_idx == 0 {
200                    &raw.0
201                } else {
202                    &table[(entry_idx - 1) * n..entry_idx * n]
203                };
204                let prod = mul_normalized(ring, &val.0, entry, &mut memory);
205                val.0.copy_from_slice(prod);
206            }
207            // val = raw ^ exp[bit..]
208            if bit == 0 {
209                break;
210            }
211            bit -= 1;
212            sqr_in_place(ring, &mut val, &mut memory);
213        }
214        val
215    }
216
217    /// Choose the optimal window size for n-bit exponents.
218    /// 1 <= window_size < min(WORD_BITS, usize::BIT_SIZE) inclusive.
219    fn choose_pow_window_len(n: usize) -> u32 {
220        // This won't overflow because cost(3) is already approximately usize::MAX / 4
221        // and it can only grow by a factor of 2.
222        let cost = |window_size| (1usize << (window_size - 1)) - 1 + n / (window_size as usize + 1);
223        let mut window_size = 1;
224        let mut c = cost(window_size);
225        while window_size + 1 < WORD_BITS.min(usize::BIT_SIZE) {
226            let c2 = cost(window_size + 1);
227            if c <= c2 {
228                break;
229            }
230            window_size += 1;
231            c = c2;
232        }
233        window_size
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pow_word() {
        let ring = ConstSingleDivisor::new(100);
        let modulo = ReducedWord(ring.0.transform(17));
        assert_eq!(single::pow_word(&ring, modulo, 0).residue(&ring), 1);
        // The special-cased exponents 1 and 2 of `pow_word`.
        assert_eq!(single::pow_word(&ring, modulo, 1).residue(&ring), 17);
        assert_eq!(single::pow_word(&ring, modulo, 2).residue(&ring), 89);
        assert_eq!(single::pow_word(&ring, modulo, 15).residue(&ring), 93);
    }

    #[test]
    fn test_pow_ubig_exponent() {
        // 17 has multiplicative order 20 modulo 100, so only `exp mod 20` matters.
        let ring = ConstSingleDivisor::new(100);
        let base = ReducedWord(ring.0.transform(17));

        // Single-word exponent through the UBig entry point.
        assert_eq!(single::pow(&ring, base, &UBig::from(15u8)).residue(&ring), 93);

        // 2^64 + 1 = 17 (mod 20); needs more than one 64-bit word, exercising the
        // split high/low word path. 17^17 = 77 (mod 100).
        let exp = UBig::from(2u8).pow(64) + UBig::ONE;
        assert_eq!(single::pow(&ring, base, &exp).residue(&ring), 77);

        // 2^200 + 3 = 19 (mod 20); a multi-word exponent. 17^19 = 53 (mod 100).
        let exp = UBig::from(2u8).pow(200) + UBig::from(3u8);
        assert_eq!(single::pow(&ring, base, &exp).residue(&ring), 53);
    }
}