use alloc::{vec, vec::Vec};
use core::iter;
use crate::{Curve, Point, Scalar};
/// Multiscalar multiplication using Straus' method with width-5 non-adjacent
/// forms and a per-point table of odd multiples.
pub struct Straus;

impl<E: Curve> super::MultiscalarMul<E> for Straus {
    /// Computes `s_1 * P_1 + ... + s_n * P_n`, returning zero for empty input.
    ///
    /// All scalars share a single double-and-add pass over the digit
    /// positions, from the most significant to the least significant: the
    /// accumulator is doubled once per position, then for every pair the
    /// precomputed multiple matching the digit's magnitude is added (positive
    /// digit) or subtracted (negative digit).
    fn multiscalar_mul<S, P>(
        scalar_points: impl ExactSizeIterator<Item = (S, P)>,
    ) -> crate::Point<E>
    where
        S: AsRef<Scalar<E>>,
        P: AsRef<Point<E>>,
    {
        let pair_count = scalar_points.len();
        let mut nafs = NafMatrix::new(5, pair_count);
        let mut tables = Vec::with_capacity(pair_count);
        for (scalar, point) in scalar_points {
            nafs.add_scalar(scalar.as_ref());
            tables.push(LookupTable::new(*point.as_ref()));
        }

        if tables.is_empty() {
            return Point::zero();
        }

        let top = nafs.naf_size - 1;
        let mut acc = Point::zero();
        for i in (0..=top).rev() {
            // Before the top digit is processed the accumulator is still the
            // zero point, so the very first doubling is skipped.
            if i != top {
                acc = acc.double();
            }
            for (digits, table) in iter::zip(nafs.iter(), &tables) {
                let digit = digits[i];
                if digit == 0 {
                    continue;
                }
                let multiple = table.get(usize::from(digit.unsigned_abs()));
                if digit > 0 {
                    acc += multiple;
                } else {
                    acc -= multiple;
                }
            }
        }
        acc
    }
}
/// Precomputed odd multiples `[P, 3P, 5P, ..., 15P]` of a single point `P`.
struct LookupTable<E: Curve>([Point<E>; 8]);

impl<E: Curve> LookupTable<E> {
    /// Builds the table for `point`: entry `k` holds `(2k + 1) * point`.
    fn new(point: Point<E>) -> Self {
        let twice = point.double();
        let mut multiples = [point; 8];
        // Each entry is the previous one plus `2 * point`.
        let mut running = point;
        for slot in multiples.iter_mut().skip(1) {
            running = twice + running;
            *slot = running;
        }
        Self(multiples)
    }

    /// Returns `x * point` for odd `x` in `1..16`.
    ///
    /// The argument is only validated in debug builds; an even `x` in release
    /// builds silently returns `(x - 1) * point`.
    fn get(&self, x: usize) -> Point<E> {
        debug_assert_eq!(x & 1, 1);
        debug_assert!(x < 16);
        let index = x >> 1;
        self.0[index]
    }
}
/// Width-`w` non-adjacent forms (wNAF) of a batch of scalars.
///
/// Digits are stored row-major in one flat buffer: row `k` holds the
/// `naf_size` signed digits of the `k`-th scalar passed to
/// [`NafMatrix::add_scalar`], least significant digit first.
struct NafMatrix<E: Curve> {
    /// Digits per scalar: one more than the scalar's bit length
    /// (`8 * Scalar::<E>::serialized_len() + 1`).
    naf_size: usize,
    /// Window width `w`, restricted to `2..=8` by [`NafMatrix::new`].
    w: usize,
    /// `2^w`.
    width: u64,
    /// `2^(w - 1)`; odd windows at or above this become negative digits.
    width_half: u64,
    /// `2^w - 1`; extracts the low `w` bits of a shifted limb.
    window_mask: u64,
    /// Flat digit storage, `naf_size` digits per scalar.
    matrix: Vec<i8>,
    _curve: core::marker::PhantomData<E>,
}

impl<E: Curve> NafMatrix<E> {
    /// Creates an empty matrix for width-`w` NAFs, preallocating room for
    /// `capacity` scalars.
    ///
    /// # Panics
    /// Panics unless `2 <= w <= 8`, which keeps every digit within an `i8`.
    fn new(w: usize, capacity: usize) -> Self {
        assert!((2..=8).contains(&w));
        let naf_size = Scalar::<E>::serialized_len() * 8 + 1;
        let width = 1 << w;
        Self {
            naf_size,
            w,
            width,
            width_half: 1 << (w - 1),
            matrix: Vec::with_capacity(naf_size * capacity),
            window_mask: width - 1,
            _curve: Default::default(),
        }
    }

    /// Appends the width-`w` NAF of `scalar` as a new row of the matrix.
    ///
    /// Every digit is zero or odd with absolute value below `2^(w - 1)`, and
    /// every non-zero digit is followed by at least `w - 1` zero digits.
    fn add_scalar(&mut self, scalar: &Scalar<E>) {
        let scalar_bytes = scalar.to_le_bytes();
        let bytes: &[u8] = &scalar_bytes;

        // Load the scalar into little-endian u64 limbs. A trailing partial
        // word (serialized length not a multiple of 8) is zero-padded into
        // its own limb instead of tripping `read_le_u64_into`'s length check.
        // One extra zero limb is always appended so that a window straddling
        // the top limb can safely read `x_u64[u64_idx + 1]`.
        let (whole, tail) = bytes.split_at(bytes.len() - bytes.len() % 8);
        let whole_limbs = whole.len() / 8;
        let mut x_u64 = vec![0u64; whole_limbs + usize::from(!tail.is_empty()) + 1];
        read_le_u64_into(whole, &mut x_u64[..whole_limbs]);
        if !tail.is_empty() {
            let mut last = [0u8; 8];
            last[..tail.len()].copy_from_slice(tail);
            x_u64[whole_limbs] = u64::from_le_bytes(last);
        }

        let offset = self.matrix.len();
        debug_assert!(
            offset + self.naf_size <= self.matrix.capacity(),
            "unnecessary allocations detected"
        );
        self.matrix.resize(offset + self.naf_size, 0i8);
        let naf = &mut self.matrix[offset..];

        let mut pos = 0;
        // Pending +1 at bit `pos`, produced when a window was mapped to a
        // negative digit.
        let mut carry = false;
        while pos < self.naf_size {
            let u64_idx = pos / 64;
            let bit_idx = pos % 64;
            // The `w` bits starting at `pos`; pull in the next limb when the
            // window may cross a limb boundary.
            let bit_buf: u64 = if bit_idx < 64 - self.w {
                (x_u64[u64_idx] >> bit_idx) & self.window_mask
            } else {
                ((x_u64[u64_idx] >> bit_idx) | (x_u64[u64_idx + 1] << (64 - bit_idx)))
                    & self.window_mask
            };
            let window = if carry { bit_buf + 1 } else { bit_buf };
            if window & 1 == 0 {
                // Even window: this digit stays zero and any carry moves on
                // to the next bit unchanged.
                pos += 1;
                continue;
            }
            if window < self.width_half {
                carry = false;
                naf[pos] = window as i8;
            } else {
                // Use the negative representative `window - 2^w` and carry the
                // compensating `2^w`, i.e. +1 at bit `pos + w`.
                carry = true;
                naf[pos] = (window as i8).wrapping_sub(self.width as i8);
            }
            pos += self.w;
        }
        debug_assert!(!carry);
    }

    /// Iterates over the rows in insertion order, yielding one
    /// `naf_size`-digit slice per added scalar.
    fn iter(&self) -> impl Iterator<Item = &[i8]> {
        self.matrix.chunks_exact(self.naf_size)
    }
}
/// Decodes `src` as consecutive little-endian `u64` words: the `i`-th 8-byte
/// chunk of `src` is written to `dst[i]`.
///
/// # Panics
/// Panics if `src.len() != 8 * dst.len()`.
fn read_le_u64_into(src: &[u8], dst: &mut [u64]) {
    assert!(
        src.len() == 8 * dst.len(),
        "src.len() = {}, dst.len() = {}",
        src.len(),
        dst.len()
    );
    dst.iter_mut()
        .zip(src.chunks_exact(8))
        .for_each(|(word, chunk)| {
            // Cannot fail: `chunks_exact(8)` only yields 8-byte chunks.
            #[allow(clippy::expect_used)]
            let le_bytes: [u8; 8] = chunk
                .try_into()
                .expect("Incorrect src length, should be 8 * dst.len()");
            *word = u64::from_le_bytes(le_bytes);
        });
}
// Unit tests for the wNAF recoding and the odd-multiple lookup table.
// The test functions are generic over `E: Curve`; the `#[instantiate_tests(...)]`
// modules at the bottom instantiate them for each listed curve.
#[cfg(test)]
#[generic_tests::define]
mod tests {
use alloc::vec::Vec;
use core::iter;
use crate::{Curve, Point, Scalar};
// Checks, for every window width 2..=8, that each NAF digit lies in
// [-2^(w-1), 2^(w-1)) and that evaluating the digits (most significant first,
// doubling the accumulator each step) reconstructs the original scalar.
// Inputs: 0, 1, -1 and 15 random scalars.
#[test]
fn non_adjacent_form_is_correct<E: Curve>() {
let mut rng = rand_dev::DevRng::new();
let scalars = iter::once(Scalar::<E>::zero())
.chain(iter::once(Scalar::one()))
.chain(iter::once(-Scalar::one()))
.chain(iter::repeat_with(|| Scalar::random(&mut rng)).take(15))
.collect::<Vec<_>>();
for w in 2..=8 {
let mut nafs = super::NafMatrix::new(w, scalars.len());
scalars.iter().for_each(|scalar| nafs.add_scalar(scalar));
for (scalar, naf) in scalars.iter().zip(nafs.iter()) {
// Printed to stderr so a failing case can be identified.
std::eprintln!("scalar {scalar:?}");
std::eprintln!("naf: {naf:?}");
assert!(naf.iter().all(|&k_i| -(1i16 << (w - 1)) <= i16::from(k_i)
&& i16::from(k_i) < (1i16 << (w - 1))));
// Horner evaluation: acc = 2 * acc + digit, from the top digit down.
let expected = naf.iter().rev().fold(Scalar::<E>::zero(), |acc, naf_i| {
acc + acc + Scalar::from(*naf_i)
});
assert_eq!(*scalar, expected)
}
}
}
// Checks that `LookupTable::get(x)` equals `x * P` for every odd x in 1..16,
// for the generator and 50 random multiples of it.
#[test]
fn lookup_table<E: Curve>() {
let mut rng = rand_dev::DevRng::new();
let points = iter::once(Point::<E>::generator().to_point())
.chain(iter::repeat_with(|| Scalar::random(&mut rng) * Point::generator()).take(50));
for point in points {
let table = super::LookupTable::new(point);
for x in (1..16).step_by(2) {
assert_eq!(table.get(x), point * Scalar::from(x));
}
}
}
// Per-curve instantiations of the generic tests above.
#[instantiate_tests(<crate::curves::Secp256k1>)]
mod secp256k1 {}
#[instantiate_tests(<crate::curves::Secp256r1>)]
mod secp256r1 {}
#[instantiate_tests(<crate::curves::Secp384r1>)]
mod secp384r1 {}
#[instantiate_tests(<crate::curves::Stark>)]
mod stark {}
#[instantiate_tests(<crate::curves::Ed25519>)]
mod ed25519 {}
}