1use crate::{
2 arch::word::Word,
3 div_const::{ConstDoubleDivisor, ConstSingleDivisor},
4 primitive::{split_dword, WORD_BITS},
5 repr::TypedReprRef::*,
6 ubig::UBig,
7};
8
9use super::repr::{Reduced, ReducedDword, ReducedRepr, ReducedWord};
10use num_modular::Reducer;
11
12impl<'a> Reduced<'a> {
13 #[inline]
30 pub fn pow(&self, exp: &UBig) -> Reduced<'a> {
31 match self.repr() {
32 ReducedRepr::Single(raw, ring) => {
33 Reduced::from_single(single::pow(ring, *raw, exp), ring)
34 }
35 ReducedRepr::Double(raw, ring) => {
36 Reduced::from_double(double::pow(ring, *raw, exp), ring)
37 }
38 ReducedRepr::Large(raw, ring) => Reduced::from_large(large::pow(ring, raw, exp), ring),
39 }
40 }
41}
42
/// Generates a module `$ns` with modular exponentiation for a modulus that fits
/// in a primitive integer.
///
/// * `$ns`   - name of the generated module.
/// * `$ring` - the constant divisor type; its `.0` field is the reducer used for
///   `sqr` / `mul`.
/// * `$raw`  - the reduced-value newtype; its `.0` field is the raw residue.
macro_rules! impl_mod_pow_for_primitive {
    ($ns:ident, $ring:ty, $raw:ident) => {
        mod $ns {
            use super::*;

            /// Computes `raw^exp` for an exponent that fits in one `Word`.
            ///
            /// Exponents 0, 1 and 2 are answered directly. Otherwise the highest set
            /// bit is covered by starting the accumulator at `raw`, and the lower
            /// bits are handled by `pow_helper`.
            #[inline]
            pub(super) fn pow_word(ring: &$ring, raw: $raw, exp: Word) -> $raw {
                if exp == 0 {
                    return <$raw>::one(ring);
                }
                if exp == 1 {
                    return raw;
                }
                if exp == 2 {
                    return $raw(ring.0.sqr(raw.0));
                }
                // Number of bits strictly below the highest set bit of `exp`.
                let lower_bits = WORD_BITS - 1 - exp.leading_zeros();
                pow_helper(ring, raw, raw, exp, lower_bits)
            }

            /// Left-to-right square-and-multiply continuation.
            ///
            /// Starting from `lhs`, visits the lowest `bits` bits of `exp` from most
            /// to least significant. For each bit it squares the accumulator, and
            /// multiplies by `rhs` when that bit is set.
            #[inline]
            fn pow_helper(ring: &$ring, lhs: $raw, rhs: $raw, exp: Word, bits: u32) -> $raw {
                let mut acc = lhs;
                for pos in (0..bits).rev() {
                    acc.0 = ring.0.sqr(acc.0);
                    if (exp >> pos) & 1 != 0 {
                        acc.0 = ring.0.mul(&acc.0, &rhs.0);
                    }
                }
                acc
            }

            /// Computes `raw^exp` for an arbitrary `UBig` exponent.
            ///
            /// A double-word exponent is split into halves: the high half is
            /// exponentiated first, then all `WORD_BITS` bits of the low half are
            /// folded in. Longer exponents go through `pow_nontrivial`.
            #[inline]
            pub(super) fn pow(ring: &$ring, raw: $raw, exp: &UBig) -> $raw {
                match exp.repr() {
                    RefSmall(dword) => match split_dword(dword) {
                        (lo, 0) => pow_word(ring, raw, lo),
                        (lo, hi) => {
                            pow_helper(ring, pow_word(ring, raw, hi), raw, lo, WORD_BITS)
                        }
                    },
                    RefLarge(words) => pow_nontrivial(ring, raw, words),
                }
            }

            /// Computes `raw^exp` where `exp_words` holds the exponent as words,
            /// least significant first.
            ///
            /// The last (most significant) word seeds the result. Each lower word,
            /// from high to low, is then folded in across its full `WORD_BITS` width.
            fn pow_nontrivial(ring: &$ring, raw: $raw, exp_words: &[Word]) -> $raw {
                let top = exp_words.len() - 1;
                exp_words[..top]
                    .iter()
                    .rev()
                    .fold(pow_word(ring, raw, exp_words[top]), |acc, &word| {
                        pow_helper(ring, acc, raw, word, WORD_BITS)
                    })
            }
        }
    };
}
// Instantiate the primitive exponentiation routines:
// `single` works with `ConstSingleDivisor` / `ReducedWord`, and
// `double` works with `ConstDoubleDivisor` / `ReducedDword`.
impl_mod_pow_for_primitive!(single, ConstSingleDivisor, ReducedWord);
impl_mod_pow_for_primitive!(double, ConstDoubleDivisor, ReducedDword);
106
/// Modular exponentiation for moduli handled by `ConstLargeDivisor`.
///
/// Exponents of two or more use sliding-window exponentiation over a
/// precomputed table of odd powers of the base.
mod large {
    use dashu_base::BitTest;

    use super::{
        super::mul::{mul_memory_requirement, mul_normalized, sqr_in_place},
        *,
    };
    use crate::{
        div_const::ConstLargeDivisor,
        error::panic_allocate_too_much,
        math,
        memory::{self, MemoryAllocation},
        modular::repr::ReducedLarge,
        primitive::{double_word, split_dword, PrimitiveUnsigned, WORD_BITS_USIZE},
    };

    /// Computes `raw^exp` in the ring described by `ring`.
    ///
    /// `exp == 0` returns the ring's one and `exp == 1` returns a clone of `raw`.
    /// Every other exponent is delegated to `pow_nontrivial`.
    pub(super) fn pow(ring: &ConstLargeDivisor, raw: &ReducedLarge, exp: &UBig) -> ReducedLarge {
        if exp.is_zero() {
            ReducedLarge::one(ring)
        } else if exp.is_one() {
            raw.clone()
        } else {
            pow_nontrivial(ring, raw, exp)
        }
    }

    /// Sliding-window exponentiation, only reached with `exp >= 2`.
    ///
    /// Builds a table of the odd powers `raw^3, raw^5, ..., raw^(2^window_len - 1)`.
    /// It then walks the bits of `exp` from the second-highest down to bit 0,
    /// doing one squaring per bit plus one multiplication per window that starts
    /// at a set bit.
    fn pow_nontrivial(ring: &ConstLargeDivisor, raw: &ReducedLarge, exp: &UBig) -> ReducedLarge {
        // Word length of the normalized modulus; each table entry holds `n` words.
        let n = ring.normalized_divisor.len();
        let window_len = choose_pow_window_len(exp.bit_len());

        // The table holds 2^(window_len - 1) - 1 entries of `n` words each.
        // NOTE(review): the closure is presumably required because
        // `panic_allocate_too_much` returns `!`, and a bare fn item of that type
        // cannot be passed as `FnOnce() -> usize` — confirm.
        #[allow(clippy::redundant_closure)]
        let table_words = ((1usize << (window_len - 1)) - 1)
            .checked_mul(n)
            .unwrap_or_else(|| panic_allocate_too_much());

        // One allocation covers both the table and the scratch space required by
        // the multiplication / squaring routines.
        let memory_requirement = memory::add_layout(
            memory::array_layout::<Word>(table_words),
            mul_memory_requirement(ring),
        );
        let mut allocation = MemoryAllocation::new(memory_requirement);
        let mut memory = allocation.memory();
        // Carve out the zero-filled table first; the remainder is scratch memory.
        let (table, mut memory) = memory.allocate_slice_fill::<Word>(table_words, 0);

        // val = raw^2
        let mut val = raw.clone();
        sqr_in_place(ring, &mut val, &mut memory);

        // Fill the table: table[i - 1] = raw^(2i + 1). Each entry is the previous
        // odd power (or `raw` itself when i == 1) multiplied by raw^2.
        for i in 1..(1 << (window_len - 1)) {
            let (prev, cur) = if i == 1 {
                (raw.0.as_ref(), &mut table[0..n])
            } else {
                // Split so the previous entry can be read while the current one is written.
                let (prev, cur) = table[(i - 2) * n..i * n].split_at_mut(n);
                (&*prev, cur)
            };
            cur.copy_from_slice(mul_normalized(ring, prev, &val.0, &mut memory));
        }

        let exp_words = exp.as_words();
        // `val` (= raw^2) already accounts for the top bit of `exp`.
        // Because exp >= 2 here, bit_len >= 2 and this cannot underflow.
        let mut bit = exp.bit_len() - 2;

        loop {
            // Invariant: val = raw^e, where e is `exp >> bit` with its lowest bit cleared.
            let word_idx = bit / WORD_BITS_USIZE;
            let bit_idx = (bit % WORD_BITS_USIZE) as u32;
            let cur_word = exp_words[word_idx];
            if cur_word & (1 << bit_idx) != 0 {
                // A window starts at this set bit. Bits below bit 0 of `exp` read as zero.
                let next_word = if word_idx == 0 {
                    0
                } else {
                    exp_words[word_idx - 1]
                };
                // Extract the `window_len` bits of `exp` whose top bit is `bit`.
                // This relies on `double_word` placing `cur_word` in the high half.
                let (mut window, _) = split_dword(
                    double_word(next_word, cur_word) >> (bit_idx + 1 + WORD_BITS - window_len),
                );
                window &= math::ones_word(window_len);
                // Drop trailing zeros so the window is odd and indexes an odd power.
                let num_bits = window_len - window.trailing_zeros();
                window >>= window_len - num_bits;
                // Shift val left (in the exponent) past the window's remaining bits.
                for _ in 0..num_bits - 1 {
                    sqr_in_place(ring, &mut val, &mut memory);
                }
                bit -= (num_bits as usize) - 1;
                debug_assert!(window & 1 == 1);
                // window == 2 * entry_idx + 1. Entry 0 is `raw` itself; the rest live
                // in the table.
                let entry_idx = (window >> 1) as usize;
                let entry = if entry_idx == 0 {
                    &raw.0
                } else {
                    &table[(entry_idx - 1) * n..entry_idx * n]
                };
                let prod = mul_normalized(ring, &val.0, entry, &mut memory);
                val.0.copy_from_slice(prod);
            }
            // Here val = raw^(exp >> bit).
            if bit == 0 {
                // val = raw^exp
                break;
            }
            bit -= 1;
            sqr_in_place(ring, &mut val, &mut memory);
        }
        val
    }

    /// Chooses the window length for an exponent of `n` bits.
    ///
    /// The estimated cost of window size `w` is `2^(w - 1) - 1`, the number of
    /// table-building multiplications, plus `n / (w + 1)`, an estimate of the
    /// number of windows. `w` grows from 1 while that estimate strictly
    /// decreases, and stays below `min(WORD_BITS, usize bit width)`. The cap
    /// keeps the `1usize << (w - 1)` shift and the `Word`-sized window in range.
    fn choose_pow_window_len(n: usize) -> u32 {
        let cost = |window_size| (1usize << (window_size - 1)) - 1 + n / (window_size as usize + 1);
        let mut window_size = 1;
        let mut c = cost(window_size);
        while window_size + 1 < WORD_BITS.min(usize::BIT_SIZE) {
            let c2 = cost(window_size + 1);
            if c <= c2 {
                break;
            }
            window_size += 1;
            c = c2;
        }
        window_size
    }
}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises every branch of `pow_word` modulo 100 with base 17: the
    /// special-cased exponents 0, 1 and 2, plus the square-and-multiply path
    /// (3 and 15).
    #[test]
    fn test_pow_word() {
        let ring = ConstSingleDivisor::new(100);
        let modulo = ReducedWord(ring.0.transform(17));
        assert_eq!(single::pow_word(&ring, modulo, 0).residue(&ring), 1);
        assert_eq!(single::pow_word(&ring, modulo, 1).residue(&ring), 17);
        assert_eq!(single::pow_word(&ring, modulo, 2).residue(&ring), 89);
        assert_eq!(single::pow_word(&ring, modulo, 3).residue(&ring), 13);
        assert_eq!(single::pow_word(&ring, modulo, 15).residue(&ring), 93);
    }

    /// Exercises the `UBig` entry point, including an exponent wider than one word.
    #[test]
    fn test_pow_ubig() {
        let ring = ConstSingleDivisor::new(100);
        let base = ReducedWord(ring.0.transform(17));
        assert_eq!(single::pow(&ring, base, &UBig::from(15u8)).residue(&ring), 93);
        // 17 is a unit mod 100 and the Carmichael function λ(100) = 20, so
        // 17^(2^64) ≡ 17^(2^64 mod 20) = 17^16 ≡ 81 (mod 100).
        // With 64-bit words this takes the double-word (hi != 0) path; with
        // 32-bit words it takes the multi-word path.
        assert_eq!(single::pow(&ring, base, &UBig::from(1u128 << 64)).residue(&ring), 81);
        assert_eq!(single::pow_word(&ring, base, 16).residue(&ring), 81);
    }
}