1use crate::ring::*;
2use crate::integer::*;
3use crate::ordered::OrderedRingStore;
4use crate::primitive_int::*;
5
6pub fn generic_abs_square_and_multiply<T, U, F, H, I>(base: U, power: &El<I>, int_ring: I, mut square: F, mut multiply_base: H, identity: T) -> T
35 where I: RingStore,
36 I::Type: IntegerRing,
37 F: FnMut(T) -> T,
38 H: FnMut(&U, T) -> T
39{
40 try_generic_abs_square_and_multiply(base, power, int_ring, |a| Ok(square(a)), |a, b| Ok(multiply_base(a, b)), identity).unwrap_or_else(|x| x)
41}
42
43#[stability::unstable(feature = "enable")]
50pub fn try_generic_abs_square_and_multiply<T, U, F, H, I, E>(base: U, power: &El<I>, int_ring: I, mut square: F, mut multiply_base: H, identity: T) -> Result<T, E>
51 where I: RingStore,
52 I::Type: IntegerRing,
53 F: FnMut(T) -> Result<T, E>,
54 H: FnMut(&U, T) -> Result<T, E>
55{
56 if int_ring.is_zero(&power) {
57 return Ok(identity);
58 } else if int_ring.is_one(&power) {
59 return multiply_base(&base, identity);
60 }
61
62 let mut result = identity;
63 for i in (0..=int_ring.abs_highest_set_bit(power).unwrap()).rev() {
64 if int_ring.abs_is_bit_set(power, i) {
65 result = multiply_base(&base, square(result)?)?;
66 } else {
67 result = square(result)?;
68 }
69 }
70 return Ok(result);
71}
72
73#[stability::unstable(feature = "enable")]
83pub fn generic_pow_shortest_chain_table<T, F, G, H, I, E>(base: T, power: &El<I>, int_ring: I, mut double: G, mut mul: F, mut clone: H, identity: T) -> Result<T, E>
84 where I: RingStore,
85 I::Type: IntegerRing,
86 F: FnMut(&T, &T) -> Result<T, E>,
87 G: FnMut(&T) -> Result<T, E>,
88 H: FnMut(&T) -> T
89{
90 assert!(!int_ring.is_neg(power));
91 if int_ring.is_zero(&power) {
92 return Ok(identity);
93 } else if int_ring.is_one(&power) {
94 return Ok(base);
95 }
96
97 let mut mult_count = 0;
98
99 const LOG2_BOUND: usize = 6;
100 const BOUND: usize = 1 << LOG2_BOUND;
101 assert!(SHORTEST_ADDITION_CHAINS.len() > BOUND);
102 let mut table = Vec::with_capacity(BOUND);
103 table.resize_with(BOUND + 1, || None);
104 table[0] = Some(identity);
105 table[1] = Some(base);
106
107 #[inline(always)]
108 fn eval_power_using_table<T, F, G, E>(power: usize, mul: &mut F, double: &mut G, table: &mut Vec<Option<T>>, mult_count: &mut usize) -> Result<(), E>
109 where F: FnMut(&T, &T) -> Result<T, E>,
110 G: FnMut(&T) -> Result<T, E>,
111 {
112 if table[power].is_none() {
113 let (i, j) = SHORTEST_ADDITION_CHAINS[power];
114 eval_power_using_table(i, mul, double, table, mult_count)?;
115 eval_power_using_table(j, mul, double, table, mult_count)?;
116 if i == j {
117 *mult_count += 1;
118 table[power] = Some(double(table[i].as_ref().unwrap())?);
119 } else {
120 *mult_count += 1;
121 table[power] = Some(mul(table[i].as_ref().unwrap(), table[j].as_ref().unwrap())?);
122 }
123 }
124 return Ok(());
125 }
126
127 let bitlen = int_ring.abs_highest_set_bit(power).unwrap() + 1;
128 if bitlen < LOG2_BOUND {
129 let power = int_cast(int_ring.clone_el(&power), StaticRing::<i32>::RING, &int_ring) as usize;
130 eval_power_using_table(power, &mut mul, &mut double, &mut table, &mut mult_count)?;
131 return Ok(table.into_iter().nth(power).unwrap().unwrap());
132 }
133
134 let start_power = (0..LOG2_BOUND).filter(|j| int_ring.abs_is_bit_set(power, *j + bitlen - LOG2_BOUND)).map(|j| 1 << j).sum::<usize>();
135 eval_power_using_table(start_power, &mut mul, &mut double, &mut table, &mut mult_count)?;
136 let mut current = clone(table[start_power].as_ref().unwrap());
137
138 for i in (0..=(bitlen - LOG2_BOUND)).rev().step_by(LOG2_BOUND).skip(1) {
139 for _ in 0..LOG2_BOUND {
140 current = double(¤t)?;
141 mult_count += 1;
142 }
143 let local_power = (0..LOG2_BOUND).filter(|j| int_ring.abs_is_bit_set(power, *j + i)).map(|j| 1 << j).sum::<usize>();
144 if local_power != 0 {
145 eval_power_using_table(local_power, &mut mul, &mut double, &mut table, &mut mult_count)?;
146 current = mul(¤t, table[local_power].as_ref().unwrap())?;
147 mult_count += 1;
148 }
149 }
150
151 if bitlen % LOG2_BOUND != 0 {
152 let final_power = (0..(bitlen % LOG2_BOUND)).filter(|j| int_ring.abs_is_bit_set(power, *j)).map(|j| 1 << j).sum::<usize>();
153 eval_power_using_table(final_power, &mut mul, &mut double, &mut table, &mut mult_count)?;
154
155 for _ in 0..(bitlen % LOG2_BOUND) {
156 current = double(¤t)?;
157 mult_count += 1;
158 }
159 if final_power != 0 {
160 current = mul(¤t, table[final_power].as_ref().unwrap())?;
161 mult_count += 1;
162 }
163 }
164
165 debug_assert!(mult_count <= bitlen * 2);
166
167 return Ok(current);
168}
169
/// Lookup table of addition-chain steps for the exponents `0..=64`.
///
/// Entry `k` is a pair `(i, j)` with `i + j == k`. `generic_pow_shortest_chain_table`
/// computes `x^k` as `x^i * x^j` (or as the square of `x^i` when `i == j`), after
/// recursively making sure `x^i` and `x^j` are available. For every `k >= 2` both `i` and
/// `j` are strictly smaller than `k`, which is what makes that recursion terminate.
///
/// Entries `0` and `1` are placeholders: `x^0` and `x^1` are seeded into the table
/// directly, so these entries are never expanded.
///
/// NOTE(review): the name suggests each entry is the last step of a shortest addition
/// chain for `k`; that optimality is not checked anywhere in this file.
const SHORTEST_ADDITION_CHAINS: [(usize, usize); 65] = [
    // 0 and 1: placeholders, never expanded
    (0, 0),
    (1, 0),
    (1, 1),
    (2, 1),
    (2, 2),
    (3, 2),
    (3, 3),
    (5, 2),
    (4, 4),
    (8, 1),
    (5, 5),
    (10, 1),
    (6, 6),
    (9, 4),
    (7, 7),
    (12, 3),
    (8, 8),
    (9, 8),
    (16, 2),
    (18, 1),
    (10, 10),
    (15, 6),
    (11, 11),
    (20, 3),
    (12, 12),
    (17, 8),
    (13, 13),
    (24, 3),
    (14, 14),
    (25, 4),
    (15, 15),
    (28, 3),
    (16, 16),
    (32, 1),
    (17, 17),
    (26, 9),
    (18, 18),
    (36, 1),
    (19, 19),
    (27, 12),
    (20, 20),
    (40, 1),
    (21, 21),
    (34, 9),
    (22, 22),
    (30, 15),
    (23, 23),
    (46, 1),
    (24, 24),
    (33, 16),
    (25, 25),
    (48, 3),
    (26, 26),
    (37, 16),
    (27, 27),
    (54, 1),
    (28, 28),
    (49, 8),
    (29, 29),
    (56, 3),
    (30, 30),
    (52, 9),
    (31, 31),
    (51, 12),
    // 64: only needed so that the table covers `0..=BOUND`
    (32, 32)
];
241
242#[cfg(test)]
243use test::Bencher;
244#[cfg(test)]
245use crate::rings::zn::zn_64;
246#[cfg(test)]
247use crate::homomorphism::*;
248
#[test]
fn test_generic_abs_square_and_multiply() {
    // Additive instance: doubling plays the role of squaring and addition the role of
    // multiplication, starting from 0. Hence "1 to the power e" must evaluate to e itself.
    for exponent in 0..(1 << 16) {
        let computed: Result<i32, !> = try_generic_abs_square_and_multiply::<_, _, _, _, _, !>(
            1,
            &exponent,
            StaticRing::<i32>::RING,
            |acc| Ok(acc * 2),
            |one, acc| Ok(one + acc),
            0
        );
        assert_eq!(Ok(exponent), computed);
    }
}
255
#[test]
fn test_generic_pow_shortest_chain_table() {
    // Additive instance: doubling plays the role of squaring and addition the role of
    // multiplication, starting from 0. Hence "1 to the power e" must evaluate to e itself.
    for exponent in 0..(1 << 16) {
        let computed: Result<i32, !> = generic_pow_shortest_chain_table::<_, _, _, _, _, !>(
            1,
            &exponent,
            StaticRing::<i32>::RING,
            |acc| Ok(acc * 2),
            |lhs, rhs| Ok(lhs + rhs),
            |value| *value,
            0
        );
        assert_eq!(Ok(exponent), computed);
    }
}
262
#[test]
fn test_shortest_addition_chain_table() {
    for (k, &(i, j)) in SHORTEST_ADDITION_CHAINS.iter().enumerate() {
        // every entry must split its index into two summands
        assert_eq!(k, i + j, "entry {} = ({}, {}) does not sum to its index", k, i, j);
        // entries 0 and 1 are seeded directly and never expanded; for all others both
        // summands must be strictly smaller than the index, otherwise the recursive lazy
        // table evaluation would never terminate
        if k >= 2 {
            assert!(i < k && j < k, "entry {} = ({}, {}) does not reduce to smaller exponents", k, i, j);
        }
    }
}
269
270#[bench]
271fn bench_standard_square_and_multiply(bencher: &mut Bencher) {
272 let ring = zn_64::Zn::new(536903681);
273 let x = ring.int_hom().map(2);
274 bencher.iter(|| {
275 assert_el_eq!(&ring, &ring.one(), try_generic_abs_square_and_multiply::<_, _, _, _, _, !>(
276 &x,
277 &536903680,
278 StaticRing::<i64>::RING,
279 |mut res| {
280 ring.square(&mut res);
281 return Ok(res);
282 },
283 |a, b| Ok(ring.mul_ref_fst(a, b)),
284 ring.one()
285 ).unwrap());
286 });
287}
288
289#[bench]
290fn bench_addchain_square_and_multiply(bencher: &mut Bencher) {
291 let ring = zn_64::Zn::new(536903681);
292 let x = ring.int_hom().map(2);
293 bencher.iter(|| {
294 assert_el_eq!(&ring, &ring.one(), generic_pow_shortest_chain_table::<_, _, _, _, _, !>(
295 x,
296 &536903680,
297 StaticRing::<i64>::RING,
298 |a| {
299 let mut res = ring.clone_el(a);
300 ring.square(&mut res);
301 return Ok(res);
302 },
303 |a, b| Ok(ring.mul_ref(a, b)),
304 |a| ring.clone_el(a),
305 ring.one()
306 ).unwrap());
307 });
308}