feanor_math/algorithms/
sqr_mul.rs

1use crate::ring::*;
2use crate::integer::*;
3use crate::ordered::OrderedRingStore;
4use crate::primitive_int::*;
5
6///
7/// Uses the square-and-multiply technique to compute the reduction of `power` times `base`
8/// w.r.t. the given operation. The operation must be associative to provide correct results.
9/// 
10/// # Example
11/// ```rust
12/// # use feanor_math::algorithms::sqr_mul::generic_abs_square_and_multiply;
13/// # use feanor_math::primitive_int::*;
14/// let mut mul_count = 0;
15/// let mut square_count = 0;
16/// // using + instead of *, we can build any number from repeated additions of 1
17/// let result = generic_abs_square_and_multiply(
18///     1,
19///     &120481,
20///     StaticRing::<i64>::RING,
21///     |x| {
22///         square_count += 1;
23///         return x + x;
24///     },
25///     |x, y| {
26///         mul_count += 1;
27///         return x + y;
28///     },
29///     0
30/// );
31/// assert_eq!(120481, result);
32/// ```
33/// 
34pub fn generic_abs_square_and_multiply<T, U, F, H, I>(base: U, power: &El<I>, int_ring: I, mut square: F, mut multiply_base: H, identity: T) -> T
35    where I: RingStore,
36        I::Type: IntegerRing,
37        F: FnMut(T) -> T, 
38        H: FnMut(&U, T) -> T
39{
40    try_generic_abs_square_and_multiply(base, power, int_ring, |a| Ok(square(a)), |a, b| Ok(multiply_base(a, b)), identity).unwrap_or_else(|x| x)
41}
42
43///
44/// Uses the square-and-multiply technique to compute the reduction of `power` times `base`
45/// w.r.t. the given operation. The operation must be associative to provide correct results.
46/// 
47/// This function aborts as soon as any operation returns `Err(_)`.
48/// 
49#[stability::unstable(feature = "enable")]
50pub fn try_generic_abs_square_and_multiply<T, U, F, H, I, E>(base: U, power: &El<I>, int_ring: I, mut square: F, mut multiply_base: H, identity: T) -> Result<T, E>
51    where I: RingStore,
52        I::Type: IntegerRing,
53        F: FnMut(T) -> Result<T, E>, 
54        H: FnMut(&U, T) -> Result<T, E>
55{
56    if int_ring.is_zero(&power) {
57        return Ok(identity);
58    } else if int_ring.is_one(&power) {
59        return multiply_base(&base, identity);
60    }
61
62    let mut result = identity;
63    for i in (0..=int_ring.abs_highest_set_bit(power).unwrap()).rev() {
64        if int_ring.abs_is_bit_set(power, i) {
65            result = multiply_base(&base, square(result)?)?;
66        } else {
67            result = square(result)?;
68        }
69    }
70    return Ok(result);
71}
72
73///
74/// Computes the reduction of `power` times `base` w.r.t. the given operation.
75/// The operation must be associative to provide correct results.
76/// 
77/// The used algorithm relies on a decomposition of `power` and a table of small shortest addition 
78/// chains to heuristically reduce the number of operations compared to [`generic_abs_square_and_multiply()`].
79/// Note that this introduces some overhead, so in cases where the operation is very cheap, prefer
80/// [`generic_abs_square_and_multiply()`].
81/// 
82#[stability::unstable(feature = "enable")]
83pub fn generic_pow_shortest_chain_table<T, F, G, H, I, E>(base: T, power: &El<I>, int_ring: I, mut double: G, mut mul: F, mut clone: H, identity: T) -> Result<T, E>
84    where I: RingStore,
85        I::Type: IntegerRing,
86        F: FnMut(&T, &T) -> Result<T, E>, 
87        G: FnMut(&T) -> Result<T, E>, 
88        H: FnMut(&T) -> T
89{
90    assert!(!int_ring.is_neg(power));
91    if int_ring.is_zero(&power) {
92        return Ok(identity);
93    } else if int_ring.is_one(&power) {
94        return Ok(base);
95    }
96
97    let mut mult_count = 0;
98
99    const LOG2_BOUND: usize = 6;
100    const BOUND: usize = 1 << LOG2_BOUND;
101    assert!(SHORTEST_ADDITION_CHAINS.len() > BOUND);
102    let mut table = Vec::with_capacity(BOUND);
103    table.resize_with(BOUND + 1, || None);
104    table[0] = Some(identity);
105    table[1] = Some(base);
106
107    #[inline(always)]
108    fn eval_power_using_table<T, F, G, E>(power: usize, mul: &mut F, double: &mut G, table: &mut Vec<Option<T>>, mult_count: &mut usize) -> Result<(), E>
109        where F: FnMut(&T, &T) -> Result<T, E>,
110            G: FnMut(&T) -> Result<T, E>, 
111    {
112        if table[power].is_none() {
113            let (i, j) = SHORTEST_ADDITION_CHAINS[power];
114            eval_power_using_table(i, mul, double, table, mult_count)?;
115            eval_power_using_table(j, mul, double, table, mult_count)?;
116            if i == j {
117                *mult_count += 1;
118                table[power] = Some(double(table[i].as_ref().unwrap())?);
119            } else {
120                *mult_count += 1;
121                table[power] = Some(mul(table[i].as_ref().unwrap(), table[j].as_ref().unwrap())?);
122            }
123        }
124        return Ok(());
125    }
126
127    let bitlen = int_ring.abs_highest_set_bit(power).unwrap() + 1;
128    if bitlen < LOG2_BOUND {
129        let power = int_cast(int_ring.clone_el(&power), StaticRing::<i32>::RING, &int_ring) as usize;
130        eval_power_using_table(power, &mut mul, &mut double, &mut table, &mut mult_count)?;
131        return Ok(table.into_iter().nth(power).unwrap().unwrap());
132    }
133
134    let start_power = (0..LOG2_BOUND).filter(|j| int_ring.abs_is_bit_set(power, *j + bitlen - LOG2_BOUND)).map(|j| 1 << j).sum::<usize>();
135    eval_power_using_table(start_power, &mut mul, &mut double, &mut table, &mut mult_count)?;
136    let mut current = clone(table[start_power].as_ref().unwrap());
137
138    for i in (0..=(bitlen - LOG2_BOUND)).rev().step_by(LOG2_BOUND).skip(1) {
139        for _ in 0..LOG2_BOUND {
140            current = double(&current)?;
141            mult_count += 1;
142        }
143        let local_power = (0..LOG2_BOUND).filter(|j| int_ring.abs_is_bit_set(power, *j + i)).map(|j| 1 << j).sum::<usize>();
144        if local_power != 0 {
145            eval_power_using_table(local_power, &mut mul, &mut double, &mut table, &mut mult_count)?;
146            current = mul(&current, table[local_power].as_ref().unwrap())?;
147            mult_count += 1;
148        }
149    }
150
151    if bitlen % LOG2_BOUND != 0 {
152        let final_power = (0..(bitlen % LOG2_BOUND)).filter(|j| int_ring.abs_is_bit_set(power, *j)).map(|j| 1 << j).sum::<usize>();
153        eval_power_using_table(final_power, &mut mul, &mut double, &mut table, &mut mult_count)?;
154        
155        for _ in 0..(bitlen % LOG2_BOUND) {
156            current = double(&current)?;
157            mult_count += 1;
158        }
159        if final_power != 0 {
160            current = mul(&current, table[final_power].as_ref().unwrap())?;
161            mult_count += 1;
162        }
163    }
164
165    debug_assert!(mult_count <= bitlen * 2);
166
167    return Ok(current);
168}
169
// The advantage of numbers < 128 is that the chains are extensions of each other,
// i.e. we can choose each shortest chain such that all of its prefixes are also the chosen
// shortest chains for the corresponding numbers. This becomes impossible for 149.
// The data is from http://wwwhomes.uni-bielefeld.de/achim/addition_chain.html
// Entry `k` is a pair `(i, j)` with `i + j == k`; eight entries per row.
const SHORTEST_ADDITION_CHAINS: [(usize, usize); 65] = [
    // 0..=7
    (0, 0), (1, 0), (1, 1), (2, 1), (2, 2), (3, 2), (3, 3), (5, 2),
    // 8..=15
    (4, 4), (8, 1), (5, 5), (10, 1), (6, 6), (9, 4), (7, 7), (12, 3),
    // 16..=23
    (8, 8), (9, 8), (16, 2), (18, 1), (10, 10), (15, 6), (11, 11), (20, 3),
    // 24..=31
    (12, 12), (17, 8), (13, 13), (24, 3), (14, 14), (25, 4), (15, 15), (28, 3),
    // 32..=39
    (16, 16), (32, 1), (17, 17), (26, 9), (18, 18), (36, 1), (19, 19), (27, 12),
    // 40..=47
    (20, 20), (40, 1), (21, 21), (34, 9), (22, 22), (30, 15), (23, 23), (46, 1),
    // 48..=55
    (24, 24), (33, 16), (25, 25), (48, 3), (26, 26), (37, 16), (27, 27), (54, 1),
    // 56..=63
    (28, 28), (49, 8), (29, 29), (56, 3), (30, 30), (52, 9), (31, 31), (51, 12),
    // 64
    (32, 32)
];
241
242#[cfg(test)]
243use test::Bencher;
244#[cfg(test)]
245use crate::rings::zn::zn_64;
246#[cfg(test)]
247use crate::homomorphism::*;
248
// With doubling as "square" and addition as "multiply", the reduction of `n` copies
// of 1 starting from 0 must be `n` itself.
#[test]
fn test_generic_abs_square_and_multiply() {
    let limit: i32 = 1 << 16;
    for exponent in 0..limit {
        let computed: Result<i32, !> = try_generic_abs_square_and_multiply::<_, _, _, _, _, !>(
            1,
            &exponent,
            StaticRing::<i32>::RING,
            |acc| Ok(acc * 2),
            |one, acc| Ok(one + acc),
            0
        );
        assert_eq!(Ok(exponent), computed);
    }
}
255
// With doubling as "double" and addition as "mul", the reduction of `n` copies
// of 1 starting from 0 must be `n` itself.
#[test]
fn test_generic_pow_shortest_chain_table() {
    let limit: i32 = 1 << 16;
    for exponent in 0..limit {
        let computed: Result<i32, !> = generic_pow_shortest_chain_table::<_, _, _, _, _, !>(
            1,
            &exponent,
            StaticRing::<i32>::RING,
            |value| Ok(value * 2),
            |lhs, rhs| Ok(lhs + rhs),
            |value| *value,
            0
        );
        assert_eq!(Ok(exponent), computed);
    }
}
262
// Checks the structural invariants of `SHORTEST_ADDITION_CHAINS` that the table-based
// power evaluation relies on.
#[test]
fn test_shortest_addition_chain_table() {
    for (n, &(i, j)) in SHORTEST_ADDITION_CHAINS.iter().enumerate() {
        // the two referenced entries must add up to the index
        assert_eq!(n, i + j);
        // Entries 0 and 1 are seeded directly; every other entry is computed by recursing
        // into both summands, which only terminates if both are strictly smaller than `n`.
        // Together with `i + j == n`, this is equivalent to both being nonzero.
        if n >= 2 {
            assert!(i >= 1 && j >= 1, "entry {} refers to itself: ({}, {})", n, i, j);
        }
    }
}
269
270#[bench]
271fn bench_standard_square_and_multiply(bencher: &mut Bencher) {
272    let ring = zn_64::Zn::new(536903681);
273    let x = ring.int_hom().map(2);
274    bencher.iter(|| {
275        assert_el_eq!(&ring, &ring.one(), try_generic_abs_square_and_multiply::<_, _, _, _, _, !>(
276            &x, 
277            &536903680, 
278            StaticRing::<i64>::RING, 
279            |mut res| {
280                ring.square(&mut res);
281                return Ok(res);
282            }, 
283            |a, b| Ok(ring.mul_ref_fst(a, b)), 
284            ring.one()
285        ).unwrap());
286    });
287}
288
289#[bench]
290fn bench_addchain_square_and_multiply(bencher: &mut Bencher) {
291    let ring = zn_64::Zn::new(536903681);
292    let x = ring.int_hom().map(2);
293    bencher.iter(|| {
294        assert_el_eq!(&ring, &ring.one(), generic_pow_shortest_chain_table::<_, _, _, _, _, !>(
295            x, 
296            &536903680, 
297            StaticRing::<i64>::RING, 
298            |a| {
299                let mut res = ring.clone_el(a);
300                ring.square(&mut res);
301                return Ok(res);
302            }, 
303            |a, b| Ok(ring.mul_ref(a, b)), 
304            |a| ring.clone_el(a),
305            ring.one()
306        ).unwrap());
307    });
308}