curve25519_dalek/backend/vector/scalar_mul/
pippenger.rs

1// -*- mode: rust; -*-
2//
3// This file is part of curve25519-dalek.
4// Copyright (c) 2019 Oleg Andreev
5// See LICENSE for licensing information.
6//
7// Authors:
8// - Oleg Andreev <oleganza@gmail.com>
9
10#![allow(non_snake_case)]
11
12#[curve25519_dalek_derive::unsafe_target_feature_specialize(
13    "avx2",
14    conditional(
15        "avx512ifma,avx512vl",
16        all(curve25519_dalek_backend = "unstable_avx512", nightly)
17    )
18)]
19pub mod spec {
20
21    use alloc::vec::Vec;
22
23    use core::borrow::Borrow;
24    use core::cmp::Ordering;
25
26    #[for_target_feature("avx2")]
27    use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
28
29    #[for_target_feature("avx512ifma")]
30    use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
31
32    use crate::edwards::EdwardsPoint;
33    use crate::scalar::Scalar;
34    use crate::traits::{Identity, VartimeMultiscalarMul};
35
36    /// Implements a version of Pippenger's algorithm.
37    ///
38    /// See the documentation in the serial `scalar_mul::pippenger` module for details.
39    pub struct Pippenger;
40
41    impl VartimeMultiscalarMul for Pippenger {
42        type Point = EdwardsPoint;
43
44        fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
45        where
46            I: IntoIterator,
47            I::Item: Borrow<Scalar>,
48            J: IntoIterator<Item = Option<EdwardsPoint>>,
49        {
50            let mut scalars = scalars.into_iter();
51            let size = scalars.by_ref().size_hint().0;
52            let w = if size < 500 {
53                6
54            } else if size < 800 {
55                7
56            } else {
57                8
58            };
59
60            let max_digit: usize = 1 << w;
61            let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
62            let buckets_count: usize = max_digit / 2; // digits are signed+centered hence 2^w/2, excluding 0-th bucket
63
64            // Collect optimized scalars and points in a buffer for repeated access
65            // (scanning the whole collection per each digit position).
66            let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
67
68            let points = points
69                .into_iter()
70                .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
71
72            let scalars_points = scalars
73                .zip(points)
74                .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
75                .collect::<Option<Vec<_>>>()?;
76
77            // Prepare 2^w/2 buckets.
78            // buckets[i] corresponds to a multiplication factor (i+1).
79            let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
80                .map(|_| ExtendedPoint::identity())
81                .collect();
82
83            let mut columns = (0..digits_count).rev().map(|digit_index| {
84                // Clear the buckets when processing another digit.
85                for bucket in &mut buckets {
86                    *bucket = ExtendedPoint::identity();
87                }
88
89                // Iterate over pairs of (point, scalar)
90                // and add/sub the point to the corresponding bucket.
91                // Note: if we add support for precomputed lookup tables,
92                // we'll be adding/subtractiong point premultiplied by `digits[i]` to buckets[0].
93                for (digits, pt) in scalars_points.iter() {
94                    // Widen digit so that we don't run into edge cases when w=8.
95                    let digit = digits[digit_index] as i16;
96                    match digit.cmp(&0) {
97                        Ordering::Greater => {
98                            let b = (digit - 1) as usize;
99                            buckets[b] = &buckets[b] + pt;
100                        }
101                        Ordering::Less => {
102                            let b = (-digit - 1) as usize;
103                            buckets[b] = &buckets[b] - pt;
104                        }
105                        Ordering::Equal => {}
106                    }
107                }
108
109                // Add the buckets applying the multiplication factor to each bucket.
110                // The most efficient way to do that is to have a single sum with two running sums:
111                // an intermediate sum from last bucket to the first, and a sum of intermediate sums.
112                //
113                // For example, to add buckets 1*A, 2*B, 3*C we need to add these points:
114                //   C
115                //   C B
116                //   C B A   Sum = C + (C+B) + (C+B+A)
117                let mut buckets_intermediate_sum = buckets[buckets_count - 1];
118                let mut buckets_sum = buckets[buckets_count - 1];
119                for i in (0..(buckets_count - 1)).rev() {
120                    buckets_intermediate_sum =
121                        &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
122                    buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
123                }
124
125                buckets_sum
126            });
127
128            // Take the high column as an initial value to avoid wasting time doubling the identity element in `fold()`.
129            let hi_column = columns.next().expect("should have more than zero digits");
130
131            Some(
132                columns
133                    .fold(hi_column, |total, p| {
134                        &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
135                    })
136                    .into(),
137            )
138        }
139    }
140
141    #[cfg(test)]
142    mod test {
143        #[test]
144        fn test_vartime_pippenger() {
145            use super::*;
146            use crate::constants;
147            use crate::scalar::Scalar;
148
149            // Reuse points across different tests
150            let mut n = 512;
151            let x = Scalar::from(2128506u64).invert();
152            let y = Scalar::from(4443282u64).invert();
153            let points: Vec<_> = (0..n)
154                .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
155                .collect();
156            let scalars: Vec<_> = (0..n)
157                .map(|i| x + (Scalar::from(i as u64) * y)) // fast way to make ~random but deterministic scalars
158                .collect();
159
160            let premultiplied: Vec<EdwardsPoint> = scalars
161                .iter()
162                .zip(points.iter())
163                .map(|(sc, pt)| sc * pt)
164                .collect();
165
166            while n > 0 {
167                let scalars = &scalars[0..n].to_vec();
168                let points = &points[0..n].to_vec();
169                let control: EdwardsPoint = premultiplied[0..n].iter().sum();
170
171                let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone());
172
173                assert_eq!(subject.compress(), control.compress());
174
175                n = n / 2;
176            }
177        }
178    }
179}