generic_ec/multiscalar/
straus.rs

1use alloc::{vec, vec::Vec};
2use core::iter;
3
4use crate::{Curve, Point, Scalar};
5
6/// Straus algorithm
7///
8/// # How it works
9/// Below we'll briefly explain how the algorithm works for better auditability. You can
10/// also refer to [original](#credits) implementation.
11///
12/// Note that algorithm is defined for a parameter $w$, however, in our implementation we hardcoded
13/// $w = 5$. It was observed in the benchmarks that $w=5$ gives the best performance for all
14/// $n$ (amount of input scalar/point pairs).
15///
16/// Recall that the multiscalar algorithm takes list of $n$ points $P_1, \dots, P_n$, and a list
17/// of $n$ scalars $s_1, \dots, s_n$, and it outputs $Q$ such that:
18///
19/// $$Q = s_1 P_1 + \dots + s_n P_n$$
20///
21/// ## Non-Adjacent Form (NAF)
22/// Straus algorithm works with scalars in Non-Adjacent Form. Each scalar $s$ is represented
23/// as:
24///
25/// $$s = s_0 2^0 + s_1 2^1 + \dots + s_k 2^k$$
26///
27/// where $-2^{w-1} \le s_i < 2^{w-1}$, each $s_i$ is either odd or zero, and $k = \log_2 s$ (most
28/// commonly, we work with scalars for which $k=256$).
29///
30/// ## Lookup tables
31/// For each point $P_i$, we precompute a lookup table $T_j = j P_i$. We only need to do that
32/// for odd $j$ up to $2^{w-1}$. In this way, NAF allows us to cut size of lookup tables by
33/// the factor of 4: we reduce size of tables by 2 because NAF has signed terms, and then we
34/// also reduce table size twice by working only with odd coefficients.
35///
36/// ## Computing the sum
37///
38/// Let's write the full sum that we need to compute:
39/// $$
40/// \begin{aligned}
41/// s_1 P_1 &=&& s_{1,0} P_1 &&+&& 2^1 s_{1,1} P_1 &&+ \dots +&& 2^{k} s_{1,k-1} P_1 \\\\
42///    \+   & &&        +    && &&            +    &&         &&                +    \\\\
43/// s_2 P_2 &=&& s_{2,0} P_2 &&+&& 2^1 s_{2,1} P_2 &&+ \dots +&& 2^{k} s_{2,k-1} P_2 \\\\
44///    \+   & &&        +    && &&            +    &&         &&                +    \\\\
45/// \vdots  & && \vdots      && && \vdots          &&         && \vdots              \\\\
46///    \+   & &&        +    && &&            +    &&         &&                +    \\\\
47/// s_n P_n &=&& s_{n,0} P_n &&+&& 2^1 s_{n,1} P_n &&+ \dots +&& 2^{k-1} s_{n,k-1} P_n
48/// \end{aligned}
49/// $$
50///
51/// Note that each $s_{i,j} P_i$ is already computed in a lookup table, and can be replaced with
52/// $T_{i, s_{i,j}}$. To compute a sum, we go column-by-column from right to left.
53///
54/// $$
55/// \begin{aligned}
56/// Q_k     &= &              &\sum_{i = 0}^n T_{i, s_{i,k}} \\\\
57/// Q_{k-1} &= &2 Q_k +       &\sum_{i = 0}^n T_{i, s_{i,k-1}} \\\\
58/// \vdots  &  &              &\\\\
59/// Q_j     &= &2 Q_{j + 1} + &\sum_{i = 0}^n T_{i, s_{i,j}} \\\\
60/// \vdots  &  &              &\\\\
61/// Q = Q_0 &= &2 Q_1 +       &\sum_{i = 0}^n T_{i, s_{i,0}} \\\\
62/// \end{aligned}
63/// $$
64///
65/// ## Credits
66/// Algorithm was adopted from [`curve25519_dalek`](curve25519) crate, with the modification that
67/// it would work with any curve, not only with ed25519. You can find original implementation
68/// [here](https://github.com/dalek-cryptography/curve25519-dalek/blob/1efe6a93b176c4389b78e81e52b2cf85d728aac6/curve25519-dalek/src/backend/serial/scalar_mul/straus.rs#L147-L201).
69pub struct Straus;
70
71impl<E: Curve> super::MultiscalarMul<E> for Straus {
72    fn multiscalar_mul<S, P>(
73        scalar_points: impl ExactSizeIterator<Item = (S, P)>,
74    ) -> crate::Point<E>
75    where
76        S: AsRef<Scalar<E>>,
77        P: AsRef<Point<E>>,
78    {
79        let mut nafs = NafMatrix::new(5, scalar_points.len());
80        let lookup_tables: Vec<_> = scalar_points
81            .into_iter()
82            .map(|(scalar, point)| {
83                nafs.add_scalar(scalar.as_ref());
84                point
85            })
86            .map(|point| LookupTable::new(*point.as_ref()))
87            .collect();
88        if lookup_tables.is_empty() {
89            return Point::zero();
90        }
91
92        let naf_size = nafs.naf_size;
93
94        let mut r = Point::zero();
95        for (i, is_first_iter) in (0..naf_size)
96            .rev()
97            .zip(iter::once(true).chain(iter::repeat(false)))
98        {
99            if !is_first_iter {
100                r = r.double();
101            }
102            for (naf, lookup_table) in nafs.iter().zip(&lookup_tables) {
103                let naf_i = naf[i];
104                match naf_i.cmp(&0) {
105                    core::cmp::Ordering::Greater => {
106                        r += lookup_table.get(naf_i.unsigned_abs().into());
107                    }
108                    core::cmp::Ordering::Less => {
109                        r -= lookup_table.get(naf_i.unsigned_abs().into());
110                    }
111                    core::cmp::Ordering::Equal => {}
112                }
113            }
114        }
115        r
116    }
117}
118
119struct LookupTable<E: Curve>([Point<E>; 8]);
120
121impl<E: Curve> LookupTable<E> {
122    /// Builds a lookup table for point $P$
123    fn new(point: Point<E>) -> Self {
124        let mut table = [point; 8];
125        let point2 = point.double();
126        for i in 0..7 {
127            table[i + 1] = point2 + table[i];
128        }
129        Self(table)
130    }
131    /// Takes odd integer $x$ such as $0 < x < 2^4$, returns $x P$
132    fn get(&self, x: usize) -> Point<E> {
133        debug_assert_eq!(x & 1, 1);
134        debug_assert!(x < 16);
135
136        self.0[x / 2]
137    }
138}
139
140/// Stores a width-$w$ "Non-Adjacent Form" (NAF) of multiple scalars
141///
142/// Width-$w$ NAF represents an integer $k$ via coefficients $k_0, \dots, k_n$ such as:
143///
144/// $$k = \sum_{i=0}^{n} k_i \cdot 2^i$$
145///
146/// where each $k_i$ is odd and lies within range $-2^{w-1} \le k_i < 2^{w-1}$.
147///
148/// Non Adjacent Form allows us to reduce size of tables we need to precompute in Straus
149/// multiscalar multiplication by factor of 4.
150struct NafMatrix<E: Curve> {
151    /// Size of one scalar in non adjacent form
152    naf_size: usize,
153    /// Input parameter `w`
154    w: usize,
155    /// width = 2^w
156    width: u64,
157    /// width_half = width / 2
158    width_half: u64,
159    /// window_mask = width - 1
160    window_mask: u64,
161    matrix: Vec<i8>,
162
163    _curve: core::marker::PhantomData<E>,
164}
165
166impl<E: Curve> NafMatrix<E> {
167    /// Construct a new matrix with parameter `w`
168    ///
169    /// Preallocates memory to fit `capacity` amount of scalars
170    fn new(w: usize, capacity: usize) -> Self {
171        assert!((2..=8).contains(&w));
172        let naf_size = Scalar::<E>::serialized_len() * 8 + 1;
173        let width = 1 << w;
174
175        Self {
176            naf_size,
177            w,
178            width,
179            width_half: 1 << (w - 1),
180            matrix: Vec::with_capacity(naf_size * capacity),
181            window_mask: width - 1,
182            _curve: Default::default(),
183        }
184    }
185    /// Adds a scalar into matrix
186    fn add_scalar(&mut self, scalar: &Scalar<E>) {
187        let scalar_bytes = scalar.to_le_bytes();
188        let mut x_u64 = vec![0u64; scalar_bytes.len() / 8 + 1];
189        read_le_u64_into(&scalar_bytes, &mut x_u64[0..4]);
190
191        let offset = self.matrix.len();
192        debug_assert!(
193            offset + self.naf_size <= self.matrix.capacity(),
194            "unnecessary allocations detected"
195        );
196        self.matrix.resize(offset + self.naf_size, 0i8);
197        let naf = &mut self.matrix[offset..];
198
199        let mut pos = 0;
200        let mut carry = false;
201        while pos < self.naf_size {
202            let u64_idx = pos / 64;
203            let bit_idx = pos % 64;
204            let bit_buf: u64 = if bit_idx < 64 - self.w {
205                // This window bits are contained in a single u64
206                (x_u64[u64_idx] >> bit_idx) & self.window_mask
207            } else {
208                // Combine the current u64's bits with the bits from the next u64
209                ((x_u64[u64_idx] >> bit_idx) | (x_u64[u64_idx + 1] << (64 - bit_idx)))
210                    & self.window_mask
211            };
212
213            // Add the carry into the current window
214            let window = if carry { bit_buf + 1 } else { bit_buf };
215
216            if window & 1 == 0 {
217                // If the window value is even, preserve the carry and continue.
218                // Why is the carry preserved?
219                // If carry == 0 and window & 1 == 0, then the next carry should be 0
220                // If carry == 1 and window & 1 == 0, then bit_buf & 1 == 1 so the next carry should be 1
221                pos += 1;
222                continue;
223            }
224
225            if window < self.width_half {
226                carry = false;
227                naf[pos] = window as i8;
228            } else {
229                carry = true;
230                naf[pos] = (window as i8).wrapping_sub(self.width as i8);
231            }
232
233            pos += self.w;
234        }
235
236        debug_assert!(!carry);
237    }
238
239    /// Iterates over scalars NAF representations in the same order as
240    /// scalars were added into the matrix
241    fn iter(&self) -> impl Iterator<Item = &[i8]> {
242        self.matrix.chunks_exact(self.naf_size)
243    }
244}
245
/// Decodes consecutive 8-byte little-endian groups of `src` into the words of `dst`.
///
/// The first 8 bytes of `src` become `dst[0]`, the next 8 bytes become `dst[1]`, and so on.
///
/// ## Panics
/// Panics if `src.len() != 8 * dst.len()`.
fn read_le_u64_into(src: &[u8], dst: &mut [u64]) {
    let (byte_count, word_count) = (src.len(), dst.len());
    assert!(
        byte_count == 8 * word_count,
        "src.len() = {}, dst.len() = {}",
        byte_count,
        word_count
    );

    for (word, chunk) in dst.iter_mut().zip(src.chunks(8)) {
        // The length check above guarantees every chunk is exactly 8 bytes long,
        // so this conversion is not expected to fail
        #[allow(clippy::expect_used)]
        let le_bytes: [u8; 8] = chunk
            .try_into()
            .expect("Incorrect src length, should be 8 * dst.len()");
        *word = u64::from_le_bytes(le_bytes);
    }
}
266
#[cfg(test)]
#[generic_tests::define]
mod tests {
    use alloc::vec::Vec;
    use core::iter;

    use crate::{Curve, Point, Scalar};

    /// For every supported `w`, checks that each NAF digit is within
    /// `-2^(w-1) <= k_i < 2^(w-1)` and that the digits evaluate back to the scalar
    #[test]
    fn non_adjacent_form_is_correct<E: Curve>() {
        let mut rng = rand_dev::DevRng::new();

        // Edge cases first (0, 1, -1), followed by random scalars
        let mut scalars = Vec::from([Scalar::<E>::zero(), Scalar::one(), -Scalar::one()]);
        scalars.extend(iter::repeat_with(|| Scalar::random(&mut rng)).take(15));

        for w in 2..=8 {
            let mut nafs = super::NafMatrix::new(w, scalars.len());
            for scalar in &scalars {
                nafs.add_scalar(scalar);
            }

            let bound = 1i16 << (w - 1);
            for (scalar, naf) in scalars.iter().zip(nafs.iter()) {
                std::eprintln!("scalar {scalar:?}");
                std::eprintln!("naf: {naf:?}");

                let digits_in_range = naf
                    .iter()
                    .map(|&k_i| i16::from(k_i))
                    .all(|k_i| -bound <= k_i && k_i < bound);
                assert!(digits_in_range);

                // Horner evaluation of the digits, most significant digit first
                let mut expected = Scalar::<E>::zero();
                for &naf_i in naf.iter().rev() {
                    expected = expected + expected + Scalar::from(naf_i);
                }
                assert_eq!(*scalar, expected)
            }
        }
    }

    /// Checks that the table returns `x * P` for every odd `x` below 16
    #[test]
    fn lookup_table<E: Curve>() {
        let mut rng = rand_dev::DevRng::new();

        // The generator itself, followed by random points
        let generator = Point::<E>::generator().to_point();
        let points = iter::once(generator).chain(
            iter::repeat_with(|| Scalar::<E>::random(&mut rng) * Point::<E>::generator()).take(50),
        );

        for point in points {
            let table = super::LookupTable::new(point);

            for half in 0..8usize {
                let x = 2 * half + 1;
                assert_eq!(table.get(x), point * Scalar::from(x));
            }
        }
    }

    #[instantiate_tests(<crate::curves::Secp256k1>)]
    mod secp256k1 {}
    #[instantiate_tests(<crate::curves::Secp256r1>)]
    mod secp256r1 {}
    #[instantiate_tests(<crate::curves::Stark>)]
    mod stark {}
    #[instantiate_tests(<crate::curves::Ed25519>)]
    mod ed25519 {}
}