// curve25519_dalek/backend/vector/scalar_mul/pippenger.rs
// Point variables use uppercase names (e.g. `P`), following mathematical notation.
#![allow(non_snake_case)]
12#[curve25519_dalek_derive::unsafe_target_feature_specialize(
13 "avx2",
14 conditional(
15 "avx512ifma,avx512vl",
16 all(curve25519_dalek_backend = "unstable_avx512", nightly)
17 )
18)]
19pub mod spec {
20
21 use alloc::vec::Vec;
22
23 use core::borrow::Borrow;
24 use core::cmp::Ordering;
25
26 #[for_target_feature("avx2")]
27 use crate::backend::vector::avx2::{CachedPoint, ExtendedPoint};
28
29 #[for_target_feature("avx512ifma")]
30 use crate::backend::vector::ifma::{CachedPoint, ExtendedPoint};
31
32 use crate::edwards::EdwardsPoint;
33 use crate::scalar::Scalar;
34 use crate::traits::{Identity, VartimeMultiscalarMul};
35
36 pub struct Pippenger;
40
41 impl VartimeMultiscalarMul for Pippenger {
42 type Point = EdwardsPoint;
43
44 fn optional_multiscalar_mul<I, J>(scalars: I, points: J) -> Option<EdwardsPoint>
45 where
46 I: IntoIterator,
47 I::Item: Borrow<Scalar>,
48 J: IntoIterator<Item = Option<EdwardsPoint>>,
49 {
50 let mut scalars = scalars.into_iter();
51 let size = scalars.by_ref().size_hint().0;
52 let w = if size < 500 {
53 6
54 } else if size < 800 {
55 7
56 } else {
57 8
58 };
59
60 let max_digit: usize = 1 << w;
61 let digits_count: usize = Scalar::to_radix_2w_size_hint(w);
62 let buckets_count: usize = max_digit / 2; let scalars = scalars.map(|s| s.borrow().as_radix_2w(w));
67
68 let points = points
69 .into_iter()
70 .map(|p| p.map(|P| CachedPoint::from(ExtendedPoint::from(P))));
71
72 let scalars_points = scalars
73 .zip(points)
74 .map(|(s, maybe_p)| maybe_p.map(|p| (s, p)))
75 .collect::<Option<Vec<_>>>()?;
76
77 let mut buckets: Vec<ExtendedPoint> = (0..buckets_count)
80 .map(|_| ExtendedPoint::identity())
81 .collect();
82
83 let mut columns = (0..digits_count).rev().map(|digit_index| {
84 for bucket in &mut buckets {
86 *bucket = ExtendedPoint::identity();
87 }
88
89 for (digits, pt) in scalars_points.iter() {
94 let digit = digits[digit_index] as i16;
96 match digit.cmp(&0) {
97 Ordering::Greater => {
98 let b = (digit - 1) as usize;
99 buckets[b] = &buckets[b] + pt;
100 }
101 Ordering::Less => {
102 let b = (-digit - 1) as usize;
103 buckets[b] = &buckets[b] - pt;
104 }
105 Ordering::Equal => {}
106 }
107 }
108
109 let mut buckets_intermediate_sum = buckets[buckets_count - 1];
118 let mut buckets_sum = buckets[buckets_count - 1];
119 for i in (0..(buckets_count - 1)).rev() {
120 buckets_intermediate_sum =
121 &buckets_intermediate_sum + &CachedPoint::from(buckets[i]);
122 buckets_sum = &buckets_sum + &CachedPoint::from(buckets_intermediate_sum);
123 }
124
125 buckets_sum
126 });
127
128 let hi_column = columns.next().expect("should have more than zero digits");
130
131 Some(
132 columns
133 .fold(hi_column, |total, p| {
134 &total.mul_by_pow_2(w as u32) + &CachedPoint::from(p)
135 })
136 .into(),
137 )
138 }
139 }
140
141 #[cfg(test)]
142 mod test {
143 #[test]
144 fn test_vartime_pippenger() {
145 use super::*;
146 use crate::constants;
147 use crate::scalar::Scalar;
148
149 let mut n = 512;
151 let x = Scalar::from(2128506u64).invert();
152 let y = Scalar::from(4443282u64).invert();
153 let points: Vec<_> = (0..n)
154 .map(|i| constants::ED25519_BASEPOINT_POINT * Scalar::from(1 + i as u64))
155 .collect();
156 let scalars: Vec<_> = (0..n)
157 .map(|i| x + (Scalar::from(i as u64) * y)) .collect();
159
160 let premultiplied: Vec<EdwardsPoint> = scalars
161 .iter()
162 .zip(points.iter())
163 .map(|(sc, pt)| sc * pt)
164 .collect();
165
166 while n > 0 {
167 let scalars = &scalars[0..n].to_vec();
168 let points = &points[0..n].to_vec();
169 let control: EdwardsPoint = premultiplied[0..n].iter().sum();
170
171 let subject = Pippenger::vartime_multiscalar_mul(scalars.clone(), points.clone());
172
173 assert_eq!(subject.compress(), control.compress());
174
175 n = n / 2;
176 }
177 }
178 }
179}