generic_ec/multiscalar/straus.rs
1use alloc::{vec, vec::Vec};
2use core::iter;
3
4use crate::{Curve, Point, Scalar};
5
/// Straus algorithm
///
/// # How it works
/// Below we briefly explain how the algorithm works for better auditability. You can
/// also refer to the [original](#credits) implementation.
///
/// Note that the algorithm is defined for a parameter $w$; however, our implementation hardcodes
/// $w = 5$. It was observed in the benchmarks that $w = 5$ gives the best performance for all
/// $n$ (number of input scalar/point pairs).
///
/// Recall that the multiscalar algorithm takes a list of $n$ points $P_1, \dots, P_n$ and a list
/// of $n$ scalars $s_1, \dots, s_n$, and it outputs $Q$ such that:
///
/// $$Q = s_1 P_1 + \dots + s_n P_n$$
///
/// ## Non-Adjacent Form (NAF)
/// Straus algorithm works with scalars in width-$w$ Non-Adjacent Form. Each scalar $s$ is
/// represented as:
///
/// $$s = s_0 2^0 + s_1 2^1 + \dots + s_k 2^k$$
///
/// where $-2^{w-1} \le s_j < 2^{w-1}$, each $s_j$ is either odd or zero, and any $w$
/// consecutive digits contain at most one non-zero digit. The NAF has $k + 1$ digits: one more
/// than the bit length of the serialized scalar (most commonly, we work with 32-byte scalars,
/// for which $k = 256$).
///
/// ## Lookup tables
/// For each point $P_i$, we precompute a lookup table $T_{i,j} = j P_i$. We only need to do
/// that for odd $j$ with $0 < j < 2^{w-1}$. In this way, NAF allows us to cut the size of lookup
/// tables by a factor of 4: tables are halved because NAF digits are signed (a negative multiple
/// is obtained by negating a positive one), and halved again because only odd digits need to be
/// stored.
///
/// ## Computing the sum
///
/// Let's write the full sum that we need to compute:
/// $$
/// \begin{aligned}
/// s_1 P_1 &=&& s_{1,0} P_1 &&+&& 2^1 s_{1,1} P_1 &&+ \dots +&& 2^{k} s_{1,k} P_1 \\\\
/// \+ & && + && && + && && + \\\\
/// s_2 P_2 &=&& s_{2,0} P_2 &&+&& 2^1 s_{2,1} P_2 &&+ \dots +&& 2^{k} s_{2,k} P_2 \\\\
/// \+ & && + && && + && && + \\\\
/// \vdots & && \vdots && && \vdots && && \vdots \\\\
/// \+ & && + && && + && && + \\\\
/// s_n P_n &=&& s_{n,0} P_n &&+&& 2^1 s_{n,1} P_n &&+ \dots +&& 2^{k} s_{n,k} P_n
/// \end{aligned}
/// $$
///
/// where $s_{i,j}$ is the $j$-th NAF digit of the scalar $s_i$.
///
/// Note that each $s_{i,j} P_i$ can be taken from a lookup table.
/// - For a positive digit it equals $T_{i, s_{i,j}}$.
/// - For a negative digit it equals $-T_{i, -s_{i,j}}$.
/// - For a zero digit it is zero, and such terms are skipped.
///
/// Below we write $T_{i, s_{i,j}}$ for this value in all three cases. To compute the sum, we go
/// column-by-column from right to left, i.e. from the most significant digit to the least
/// significant one, doubling the accumulator between columns:
///
/// $$
/// \begin{aligned}
/// Q_k &= & &\sum_{i = 1}^n T_{i, s_{i,k}} \\\\
/// Q_{k-1} &= &2 Q_k + &\sum_{i = 1}^n T_{i, s_{i,k-1}} \\\\
/// \vdots & & &\\\\
/// Q_j &= &2 Q_{j + 1} + &\sum_{i = 1}^n T_{i, s_{i,j}} \\\\
/// \vdots & & &\\\\
/// Q = Q_0 &= &2 Q_1 + &\sum_{i = 1}^n T_{i, s_{i,0}} \\\\
/// \end{aligned}
/// $$
///
/// ## Variable time
/// The implementation branches on the NAF digits of the scalars and indexes the lookup tables
/// with them, so its running time depends on the values of the scalars.
///
/// ## Credits
/// The algorithm was adapted from the
/// [`curve25519-dalek`](https://crates.io/crates/curve25519-dalek) crate, with the modification
/// that it works with any curve, not only with ed25519. You can find the original implementation
/// [here](https://github.com/dalek-cryptography/curve25519-dalek/blob/1efe6a93b176c4389b78e81e52b2cf85d728aac6/curve25519-dalek/src/backend/serial/scalar_mul/straus.rs#L147-L201).
pub struct Straus;
70
71impl<E: Curve> super::MultiscalarMul<E> for Straus {
72 fn multiscalar_mul<S, P>(
73 scalar_points: impl ExactSizeIterator<Item = (S, P)>,
74 ) -> crate::Point<E>
75 where
76 S: AsRef<Scalar<E>>,
77 P: AsRef<Point<E>>,
78 {
79 let mut nafs = NafMatrix::new(5, scalar_points.len());
80 let lookup_tables: Vec<_> = scalar_points
81 .into_iter()
82 .map(|(scalar, point)| {
83 nafs.add_scalar(scalar.as_ref());
84 point
85 })
86 .map(|point| LookupTable::new(*point.as_ref()))
87 .collect();
88 if lookup_tables.is_empty() {
89 return Point::zero();
90 }
91
92 let naf_size = nafs.naf_size;
93
94 let mut r = Point::zero();
95 for (i, is_first_iter) in (0..naf_size)
96 .rev()
97 .zip(iter::once(true).chain(iter::repeat(false)))
98 {
99 if !is_first_iter {
100 r = r.double();
101 }
102 for (naf, lookup_table) in nafs.iter().zip(&lookup_tables) {
103 let naf_i = naf[i];
104 match naf_i.cmp(&0) {
105 core::cmp::Ordering::Greater => {
106 r += lookup_table.get(naf_i.unsigned_abs().into());
107 }
108 core::cmp::Ordering::Less => {
109 r -= lookup_table.get(naf_i.unsigned_abs().into());
110 }
111 core::cmp::Ordering::Equal => {}
112 }
113 }
114 }
115 r
116 }
117}
118
119struct LookupTable<E: Curve>([Point<E>; 8]);
120
121impl<E: Curve> LookupTable<E> {
122 /// Builds a lookup table for point $P$
123 fn new(point: Point<E>) -> Self {
124 let mut table = [point; 8];
125 let point2 = point.double();
126 for i in 0..7 {
127 table[i + 1] = point2 + table[i];
128 }
129 Self(table)
130 }
131 /// Takes odd integer $x$ such as $0 < x < 2^4$, returns $x P$
132 fn get(&self, x: usize) -> Point<E> {
133 debug_assert_eq!(x & 1, 1);
134 debug_assert!(x < 16);
135
136 self.0[x / 2]
137 }
138}
139
140/// Stores a width-$w$ "Non-Adjacent Form" (NAF) of multiple scalars
141///
142/// Width-$w$ NAF represents an integer $k$ via coefficients $k_0, \dots, k_n$ such as:
143///
144/// $$k = \sum_{i=0}^{n} k_i \cdot 2^i$$
145///
146/// where each $k_i$ is odd and lies within range $-2^{w-1} \le k_i < 2^{w-1}$.
147///
148/// Non Adjacent Form allows us to reduce size of tables we need to precompute in Straus
149/// multiscalar multiplication by factor of 4.
150struct NafMatrix<E: Curve> {
151 /// Size of one scalar in non adjacent form
152 naf_size: usize,
153 /// Input parameter `w`
154 w: usize,
155 /// width = 2^w
156 width: u64,
157 /// width_half = width / 2
158 width_half: u64,
159 /// window_mask = width - 1
160 window_mask: u64,
161 matrix: Vec<i8>,
162
163 _curve: core::marker::PhantomData<E>,
164}
165
166impl<E: Curve> NafMatrix<E> {
167 /// Construct a new matrix with parameter `w`
168 ///
169 /// Preallocates memory to fit `capacity` amount of scalars
170 fn new(w: usize, capacity: usize) -> Self {
171 assert!((2..=8).contains(&w));
172 let naf_size = Scalar::<E>::serialized_len() * 8 + 1;
173 let width = 1 << w;
174
175 Self {
176 naf_size,
177 w,
178 width,
179 width_half: 1 << (w - 1),
180 matrix: Vec::with_capacity(naf_size * capacity),
181 window_mask: width - 1,
182 _curve: Default::default(),
183 }
184 }
185 /// Adds a scalar into matrix
186 fn add_scalar(&mut self, scalar: &Scalar<E>) {
187 let scalar_bytes = scalar.to_le_bytes();
188 let mut x_u64 = vec![0u64; scalar_bytes.len() / 8 + 1];
189 read_le_u64_into(&scalar_bytes, &mut x_u64[0..4]);
190
191 let offset = self.matrix.len();
192 debug_assert!(
193 offset + self.naf_size <= self.matrix.capacity(),
194 "unnecessary allocations detected"
195 );
196 self.matrix.resize(offset + self.naf_size, 0i8);
197 let naf = &mut self.matrix[offset..];
198
199 let mut pos = 0;
200 let mut carry = false;
201 while pos < self.naf_size {
202 let u64_idx = pos / 64;
203 let bit_idx = pos % 64;
204 let bit_buf: u64 = if bit_idx < 64 - self.w {
205 // This window bits are contained in a single u64
206 (x_u64[u64_idx] >> bit_idx) & self.window_mask
207 } else {
208 // Combine the current u64's bits with the bits from the next u64
209 ((x_u64[u64_idx] >> bit_idx) | (x_u64[u64_idx + 1] << (64 - bit_idx)))
210 & self.window_mask
211 };
212
213 // Add the carry into the current window
214 let window = if carry { bit_buf + 1 } else { bit_buf };
215
216 if window & 1 == 0 {
217 // If the window value is even, preserve the carry and continue.
218 // Why is the carry preserved?
219 // If carry == 0 and window & 1 == 0, then the next carry should be 0
220 // If carry == 1 and window & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1
221 pos += 1;
222 continue;
223 }
224
225 if window < self.width_half {
226 carry = false;
227 naf[pos] = window as i8;
228 } else {
229 carry = true;
230 naf[pos] = (window as i8).wrapping_sub(self.width as i8);
231 }
232
233 pos += self.w;
234 }
235
236 debug_assert!(!carry);
237 }
238
239 /// Iterates over scalars NAF representations in the same order as
240 /// scalars were added into the matrix
241 fn iter(&self) -> impl Iterator<Item = &[i8]> {
242 self.matrix.chunks_exact(self.naf_size)
243 }
244}
245
246/// Read one or more u64s stored as little endian bytes.
247///
248/// ## Panics
249/// Panics if `src.len() != 8 * dst.len()`.
250fn read_le_u64_into(src: &[u8], dst: &mut [u64]) {
251 assert!(
252 src.len() == 8 * dst.len(),
253 "src.len() = {}, dst.len() = {}",
254 src.len(),
255 dst.len()
256 );
257 for (bytes, val) in src.chunks(8).zip(dst.iter_mut()) {
258 *val = u64::from_le_bytes(
259 #[allow(clippy::expect_used)]
260 bytes
261 .try_into()
262 .expect("Incorrect src length, should be 8 * dst.len()"),
263 );
264 }
265}
266
267#[cfg(test)]
268#[generic_tests::define]
269mod tests {
270 use alloc::vec::Vec;
271 use core::iter;
272
273 use crate::{Curve, Point, Scalar};
274
275 #[test]
276 fn non_adjacent_form_is_correct<E: Curve>() {
277 let mut rng = rand_dev::DevRng::new();
278
279 let scalars = iter::once(Scalar::<E>::zero())
280 .chain(iter::once(Scalar::one()))
281 .chain(iter::once(-Scalar::one()))
282 .chain(iter::repeat_with(|| Scalar::random(&mut rng)).take(15))
283 .collect::<Vec<_>>();
284
285 for w in 2..=8 {
286 let mut nafs = super::NafMatrix::new(w, scalars.len());
287 scalars.iter().for_each(|scalar| nafs.add_scalar(scalar));
288
289 for (scalar, naf) in scalars.iter().zip(nafs.iter()) {
290 std::eprintln!("scalar {scalar:?}");
291 std::eprintln!("naf: {naf:?}");
292
293 assert!(naf.iter().all(|&k_i| -(1i16 << (w - 1)) <= i16::from(k_i)
294 && i16::from(k_i) < (1i16 << (w - 1))));
295
296 let expected = naf.iter().rev().fold(Scalar::<E>::zero(), |acc, naf_i| {
297 acc + acc + Scalar::from(*naf_i)
298 });
299 assert_eq!(*scalar, expected)
300 }
301 }
302 }
303
304 #[test]
305 fn lookup_table<E: Curve>() {
306 let mut rng = rand_dev::DevRng::new();
307
308 let points = iter::once(Point::<E>::generator().to_point())
309 .chain(iter::repeat_with(|| Scalar::random(&mut rng) * Point::generator()).take(50));
310 for point in points {
311 let table = super::LookupTable::new(point);
312
313 for x in (1..16).step_by(2) {
314 assert_eq!(table.get(x), point * Scalar::from(x));
315 }
316 }
317 }
318
319 #[instantiate_tests(<crate::curves::Secp256k1>)]
320 mod secp256k1 {}
321 #[instantiate_tests(<crate::curves::Secp256r1>)]
322 mod secp256r1 {}
323 #[instantiate_tests(<crate::curves::Stark>)]
324 mod stark {}
325 #[instantiate_tests(<crate::curves::Ed25519>)]
326 mod ed25519 {}
327}