1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
//! Multiscalar multiplication implementations based on Pippenger's algorithm.
use crate::{
    g1::{G1Affine, G1Projective},
    scalar::Scalar,
};
use byteorder;

#[cfg(feature = "std")]
/// Performs multiscalar multiplication relying on Pippenger's algorithm.
///
/// Every scalar is recoded into signed radix-`2^w` digits (see
/// `to_radix_2w`), and the digit positions ("columns") are processed from
/// the most significant to the least significant one. `points` and
/// `scalars` are paired up with `zip`, so surplus elements of the longer
/// iterator are ignored.
///
/// The window width `w` is picked from the lower bound of the scalars'
/// `size_hint`; it only tunes the amount of work, not the result.
///
/// This method was taken from `curve25519-dalek` and was originally made by
/// Oleg Andreev <oleganza@gmail.com>.
pub fn pippenger<P, I>(points: P, scalars: I) -> G1Projective
where
    P: Iterator<Item = G1Projective>,
    I: Iterator<Item = Scalar>,
{
    // Wider digits mean fewer point additions per column, but the number of
    // buckets (and bucket additions) grows exponentially with the width.
    let w: usize = match scalars.size_hint().0 {
        0..=499 => 6,
        500..=799 => 7,
        _ => 8,
    };

    let digits_count: usize = to_radix_2w_size_hint(w);
    // Digits are signed and centered, so only 2^w / 2 magnitudes exist;
    // the zero digit needs no bucket.
    let buckets_count: usize = (1usize << w) / 2;

    // Recode the scalars once and keep them next to their points: the whole
    // set is scanned again for every digit position.
    let scalars_points: Vec<_> = scalars
        .map(|s| to_radix_2w(&s, w))
        .zip(points)
        .collect();

    // buckets[i] accumulates the points whose digit has magnitude (i + 1).
    let mut buckets = vec![G1Projective::identity(); buckets_count];

    let mut columns = (0..digits_count).rev().map(|digit_index| {
        // Start every column from empty buckets.
        for bucket in buckets.iter_mut() {
            *bucket = G1Projective::identity();
        }

        // Add (positive digit) or subtract (negative digit) each point
        // to/from the bucket selected by the digit's magnitude.
        for (digits, pt) in &scalars_points {
            // Widened to i16 so that negating the digit cannot overflow when w = 8.
            let digit = i16::from(digits[digit_index]);
            if digit == 0 {
                continue;
            }
            let slot = (digit.abs() - 1) as usize;
            buckets[slot] = if digit > 0 {
                buckets[slot] + pt
            } else {
                buckets[slot] - pt
            };
        }

        // Weighted sum of the buckets, sum((i + 1) * buckets[i]), computed
        // with two running sums walking from the last bucket to the first.
        //
        // For buckets 1*A, 2*B, 3*C this adds up:
        //   C
        //   C B
        //   C B A   Sum = C + (C+B) + (C+B+A)
        let top = buckets[buckets_count - 1];
        let (_, column_sum) = buckets[..buckets_count - 1].iter().rev().fold(
            (top, top),
            |(mut partial, mut total), bucket| {
                partial += *bucket;
                total += partial;
                (partial, total)
            },
        );

        column_sum
    });

    // Seed the accumulator with the highest column so no doublings are spent
    // on the identity element. `unwrap()` always succeeds because there is
    // always more than zero digits.
    let mut total = columns.next().unwrap();
    for column in columns {
        total = mul_by_pow_2(&total, w as u32) + column;
    }
    total
}

/// Compute \\([2\^k] P \\) by `k` successive doublings.
///
/// Callers are expected to pass \\( k > 0 \\); this is checked with a
/// `debug_assert!`. When debug assertions are disabled, `k == 0` returns
/// `point` unchanged (\\([2\^0] P = P\\)).
pub(crate) fn mul_by_pow_2(point: &G1Projective, k: u32) -> G1Projective {
    debug_assert!(k > 0);
    // Iterating over `0..k` avoids computing `k - 1`, which would wrap
    // around to `u32::MAX` for `k == 0` once overflow checks are off.
    (0..k).fold(*point, |acc, _| acc.double())
}

/// Returns how many leading entries of the array produced by `to_radix_2w`
/// may be populated for a window width of `w` bits.
///
/// That is `ceil(256 / w)` digits, plus one extra digit when `w == 8` to
/// hold the terminal carry (see the carry handling in `to_radix_2w`).
///
/// # Panics
///
/// Panics with "invalid radix parameter" if `w` is not 6, 7 or 8.
fn to_radix_2w_size_hint(w: usize) -> usize {
    debug_assert!(w >= 6);
    debug_assert!(w <= 8);

    // Number of `w`-bit windows needed to cover `bits` bits, rounded up.
    // Kept lazy so an unsupported `w` reaches the `panic!` below first.
    let windows_for = |bits: usize| (bits + w - 1) / w;

    let digits_count = match w {
        6 | 7 => windows_for(256),
        // One more digit to absorb the final carry.
        8 => windows_for(256) + 1,
        _ => panic!("invalid radix parameter"),
    };

    debug_assert!(digits_count <= 43);
    digits_count
}

/// Recodes `scalar` into signed radix-`2^w` digits, least significant first.
///
/// The 32 bytes returned by `scalar.to_bytes()` are read as four
/// little-endian `u64` limbs and cut into `ceil(256 / w)` windows of `w`
/// bits. Each window is recentered from `[0, 2^w)` to `[-2^w/2, 2^w/2)`,
/// pushing a carry into the next window. Unused trailing entries of the
/// returned array stay zero.
///
/// `w` is expected to be 6, 7 or 8 (checked in debug builds only).
fn to_radix_2w(scalar: &Scalar, w: usize) -> [i8; 43] {
    debug_assert!(w >= 6);
    debug_assert!(w <= 8);

    use byteorder::{ByteOrder, LittleEndian};

    // The scalar's bytes as four little-endian 64-bit limbs.
    let mut limbs = [0u64; 4];
    LittleEndian::read_u64_into(&scalar.to_bytes(), &mut limbs);

    let radix: u64 = 1 << w;
    let half_radix: u64 = radix / 2;
    let window_mask: u64 = radix - 1;

    let mut carry = 0u64;
    let mut digits = [0i8; 43];
    let digits_count = (256 + w - 1) / w;
    for i in 0..digits_count {
        // Locate the window that starts at bit `i * w`.
        let bit_offset = i * w;
        let limb_idx = bit_offset / 64;
        let shift = bit_offset % 64;

        // A window that does not fit below the top of its limb continues
        // in the next limb, unless it already sits in the last one.
        let bit_buf: u64 = if shift >= 64 - w && limb_idx != 3 {
            (limbs[limb_idx] >> shift) | (limbs[limb_idx + 1] << (64 - shift))
        } else {
            limbs[limb_idx] >> shift
        };

        // Window value plus the carry coming from the previous window.
        let coef = carry + (bit_buf & window_mask);

        // Recenter from [0, 2^w) to [-2^w/2, 2^w/2): values in the upper
        // half borrow 2^w from the next digit.
        carry = (coef + half_radix) >> w;
        digits[i] = ((coef as i64) - (carry << w) as i64) as i8;
    }

    // For w < 8 the final carry is folded back onto the last digit.
    //
    // For w = 8, carry * 2^w does not fit into an i8, so the carry gets a
    // digit of its own; `to_radix_2w_size_hint` reports one more digit for
    // w = 8 to account for it.
    // NOTE(review): the final carry is presumably always 0 for canonical
    // (255-bit) scalars; confirm against the `Scalar` invariants.
    if w == 8 {
        digits[digits_count] += carry as i8;
    } else {
        digits[digits_count - 1] += (carry << w) as i8;
    }

    digits
}

#[cfg(feature = "std")]
/// Performs a Variable Base Multiscalar Multiplication.
///
/// Computes `sum(scalars[i] * points[i])` with a windowed bucket method.
/// The 255 scalar bits are cut into windows of `c` bits, the windows are
/// processed in parallel with rayon, and the per-window sums are finally
/// combined from the highest window down to the lowest.
///
/// `points` and `scalars` are paired up with `zip`, so surplus elements of
/// the longer slice are ignored.
pub fn msm_variable_base(points: &[G1Affine], scalars: &[Scalar]) -> G1Projective {
    use rayon::prelude::*;

    // Window width in bits, chosen from the number of scalars.
    let count = scalars.len();
    let c = if count >= 32 {
        ln_without_floats(count) + 2
    } else {
        3
    };

    let num_bits = 255usize;
    let fr_one = Scalar::one();
    let zero = G1Projective::identity();

    // Number of distinct window values, used to mask out one window.
    let window_modulus: u64 = 1 << c;

    // Starting bit of every `c`-bit window covering 0..num_bits.
    let window_starts: Vec<_> = (0..num_bits).step_by(c).collect();

    // Process each window independently and in parallel.
    let window_sums: Vec<G1Projective> = window_starts
        .into_par_iter()
        .map(|w_start| {
            let mut res = zero;
            // The "zero" bucket is never needed, hence only 2^c - 1 buckets.
            let mut buckets = vec![zero; (1 << c) - 1];

            for (scalar, base) in scalars.iter().zip(points) {
                // Zero scalars contribute nothing.
                if *scalar == Scalar::zero() {
                    continue;
                }

                // Unit scalars are added directly, once, in the first window.
                if *scalar == fr_one {
                    if w_start == 0 {
                        res = res.add_mixed(base);
                    }
                    continue;
                }

                // Drop the bits below this window...
                let mut shifted = scalar.reduce();
                shifted.divn(w_start as u32);

                // ...and keep only the `c` bits that belong to it.
                let window_value = shifted.0[0] % window_modulus;

                // A non-zero window value `v` selects bucket `v - 1`
                // (there is no bucket for zero).
                if window_value != 0 {
                    let slot = (window_value - 1) as usize;
                    buckets[slot] = buckets[slot].add_mixed(base);
                }
            }

            // Weighted bucket sum: walking from the last bucket down, bucket
            // `i` ends up being added (i + 1) times.
            let mut running_sum = G1Projective::identity();
            for bucket in buckets.iter().rev() {
                running_sum += bucket;
                res += &running_sum;
            }

            res
        })
        .collect();

    // The lowest window is added last, without any further doubling.
    let (lowest, higher_windows) = window_sums.split_first().unwrap();

    // Walk the remaining windows from high to low, shifting the running
    // total up by one window (`c` doublings) after each addition.
    let mut total = zero;
    for window_sum in higher_windows.iter().rev() {
        total += window_sum;
        for _ in 0..c {
            total = total.double();
        }
    }

    total + *lowest
}

/// Integer approximation of the natural logarithm of `a`.
///
/// Uses `ln(a) = log2(a) * ln(2)` with `ln(2)` approximated by `69 / 100`,
/// rounding down. The `log2` helper of this module supplies the base-2 part.
fn ln_without_floats(a: usize) -> usize {
    let scaled = log2(a) * 69;
    (scaled / 100) as usize
}
/// Integer base-2 logarithm estimate of `x`.
///
/// Returns 0 for `x <= 1`; otherwise returns the number of bits needed to
/// represent `x` (i.e. `floor(log2(x)) + 1`, so `log2(2) == 2`).
fn log2(x: usize) -> u32 {
    match x {
        0 | 1 => 0,
        _ => {
            let usize_bits = (core::mem::size_of::<usize>() * 8) as u32;
            usize_bits - x.leading_zeros()
        }
    }
}

// Both tests exercise functions that only exist with the `std` feature, so
// the whole module is compiled only for test builds with `std` enabled.
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;

    /// Checks `pippenger` against a naive sum of scalar multiplications for
    /// input sizes 512, 256, ..., 1, which covers the window widths 6 and 7.
    #[test]
    fn pippenger_test() {
        // Reuse points across different tests
        let mut n = 512;
        let x = Scalar::from(2128506u64).invert().unwrap();
        let y = Scalar::from(4443282u64).invert().unwrap();
        let points = (0..n)
            .map(|i| G1Projective::generator() * Scalar::from(1 + i as u64))
            .collect::<Vec<_>>();
        let scalars = (0..n)
            .map(|i| x + (Scalar::from(i as u64) * y))
            .collect::<Vec<_>>(); // fast way to make ~random but deterministic scalars
        let premultiplied: Vec<G1Projective> = scalars
            .iter()
            .zip(points.iter())
            .map(|(sc, pt)| pt * sc)
            .collect();
        while n > 0 {
            let scalars = &scalars[0..n];
            let points = &points[0..n];
            let control: G1Projective = premultiplied[0..n].iter().sum();
            // Both item types are `Copy`, so no intermediate vectors are needed.
            let subject = pippenger(points.iter().copied(), scalars.iter().copied());
            assert_eq!(subject, control);
            n /= 2;
        }
    }

    /// Checks `msm_variable_base` on a single (generator, 100) pair.
    #[test]
    fn msm_variable_base_test() {
        let points = vec![G1Affine::generator()];
        let scalars = vec![Scalar::from(100u64)];
        let premultiplied = G1Projective::generator() * Scalar::from(100u64);
        let subject = msm_variable_base(&points, &scalars);
        assert_eq!(subject, premultiplied);
    }
}