use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::successors;
use std::marker::PhantomData;
use std::ops::Add;
use crate::groups::multiplier::integer_utils::{get_bits_from_bytes, is_power_of_2, test_bit};
use crate::groups::multiplier::{integer_utils, ScalarMultiplier, ToLittleEndianBytes};
use crate::groups::Doubling;
/// Scalar multiplier for a fixed base element `g`, backed by a table of its small multiples.
///
/// * `CACHE_SIZE` - number of cached multiples; the cache holds `[0, g, 2g, ..., (CACHE_SIZE - 1) g]`
///   and must have a size that is a power of two greater than one (enforced by `new`).
/// * `SLIDING_WINDOW_WIDTH` - window width used for the second, non-cached element in
///   `two_scalar_mul`.
pub struct WindowedScalarMultiplier<G, S, const CACHE_SIZE: usize, const SLIDING_WINDOW_WIDTH: usize>
{
    /// `cache[k]` is `k` times the base element.
    cache: [G; CACHE_SIZE],
    /// Ties the multiplier to one scalar type without storing a scalar.
    _scalar: PhantomData<S>,
}

impl<G, S, const CACHE_SIZE: usize, const SLIDING_WINDOW_WIDTH: usize>
    WindowedScalarMultiplier<G, S, CACHE_SIZE, SLIDING_WINDOW_WIDTH>
{
    /// Number of scalar bits consumed per digit in `mul`: `log2(CACHE_SIZE)`.
    const WINDOW_WIDTH: usize = integer_utils::log2(CACHE_SIZE);
}
impl<
        G: for<'a> Add<&'a G, Output = G> + Doubling + Clone + Debug,
        S: ToLittleEndianBytes + Clone + Debug,
        const CACHE_SIZE: usize,
        const SLIDING_WINDOW_WIDTH: usize,
    > ScalarMultiplier<G, S> for WindowedScalarMultiplier<G, S, CACHE_SIZE, SLIDING_WINDOW_WIDTH>
{
    /// Builds the cache `[zero, g, 2g, ..., (CACHE_SIZE - 1) g]` for `g = base_element`.
    ///
    /// # Panics
    /// Panics if `CACHE_SIZE` is not a power of two greater than 1.
    fn new(base_element: G, zero: G) -> Self {
        assert!(
            is_power_of_2(CACHE_SIZE) && CACHE_SIZE > 1,
            "CACHE_SIZE must be a power of two greater than 1"
        );

        // Each entry is the previous one plus the base element.
        let mut multiples = vec![zero, base_element.clone()];
        while multiples.len() < CACHE_SIZE {
            let next = multiples[multiples.len() - 1].clone() + &base_element;
            multiples.push(next);
        }

        let cache: [G; CACHE_SIZE] = multiples.try_into().unwrap();
        Self {
            cache,
            _scalar: PhantomData,
        }
    }

    /// Computes `scalar * g` with a fixed-window method.
    ///
    /// The scalar is split into base-`2^WINDOW_WIDTH` digits. Starting from the last digit
    /// (the most significant one, since the digits are consumed in reverse), each step shifts
    /// the accumulator by one digit (`WINDOW_WIDTH` doublings) and adds the cached multiple
    /// selected by the next digit.
    fn mul(&self, scalar: &S) -> G {
        let bytes = scalar.to_le_bytes();
        let digits = integer_utils::compute_base_2w_expansion(&bytes, Self::WINDOW_WIDTH);

        let leading = digits[digits.len() - 1];
        digits
            .iter()
            .rev()
            .skip(1)
            .fold(self.cache[leading].clone(), |acc, &digit| {
                let shifted = (0..Self::WINDOW_WIDTH).fold(acc, |value, _| value.double());
                shifted + &self.cache[digit]
            })
    }

    /// Computes `base_scalar * g + other_scalar * other_element` with `multi_scalar_mul`.
    ///
    /// The upper half of the cache, `[(CACHE_SIZE / 2) g, ..., (CACHE_SIZE - 1) g]`, is passed
    /// as the precomputed table for `g`. `other_element` gets a table of width
    /// `SLIDING_WINDOW_WIDTH` computed on the fly.
    fn two_scalar_mul(&self, base_scalar: &S, other_element: &G, other_scalar: &S) -> G {
        let scalars = [base_scalar.clone(), other_scalar.clone()];
        let elements = [self.cache[1].clone(), other_element.clone()];
        let base_table = self.cache[CACHE_SIZE / 2..].to_vec();
        let precomputed = HashMap::from([(0, base_table)]);
        let identity = self.cache[0].clone();
        multi_scalar_mul(
            &scalars,
            &elements,
            &precomputed,
            SLIDING_WINDOW_WIDTH,
            identity,
        )
    }
}
/// Computes the multi-scalar multiplication `sum_i scalars[i] * elements[i]`.
///
/// All terms share one left-to-right double-and-add pass over the scalar bits. Each term has
/// its own sliding window: when a set bit starts a window of width `w > 1`, the window's bits
/// select one precomputed multiple. That multiple is added when the pass reaches the window's
/// lowest bit. A set bit that starts no window (width `w <= 1`, or too few bits left for a full
/// window) adds the element itself.
///
/// # Arguments
/// * `scalars` - the scalars, read through their little-endian byte encoding.
/// * `elements` - the group elements; `elements[i]` is multiplied by `scalars[i]`.
/// * `precomputed_multiples` - optional tables keyed by term index. For element `g`, a table
///   of length `2^(w-1)` must be `[2^(w-1) g, (2^(w-1) + 1) g, ..., (2^w - 1) g]` (the layout
///   produced by `compute_multiples`). Its window width `w` is derived from its length.
/// * `default_window_width` - window width for terms without a precomputed table. Their
///   table is computed here.
/// * `zero` - the group identity, returned when no term contributes (`N == 0` or all
///   scalars are zero).
///
/// # Panics
/// Panics if some term has no precomputed table and `default_window_width` is zero.
pub fn multi_scalar_mul<
    G: Doubling + for<'a> Add<&'a G, Output = G> + Clone + Debug,
    S: ToLittleEndianBytes + Clone + Debug,
    const N: usize,
>(
    scalars: &[S; N],
    elements: &[G; N],
    precomputed_multiples: &HashMap<usize, Vec<G>>,
    default_window_width: usize,
    zero: G,
) -> G {
    /// Adds `term` to the running sum, where `None` stands for the identity. This lets the
    /// caller skip doublings and additions of the identity before the first term is added.
    fn accumulate<T: for<'a> Add<&'a T, Output = T> + Clone>(sum: Option<T>, term: &T) -> T {
        match sum {
            Some(sum) => sum + term,
            None => term.clone(),
        }
    }

    if N == 0 {
        return zero;
    }

    // Tables for the terms the caller did not precompute.
    let missing_precomputations: HashMap<usize, Vec<G>> = elements
        .iter()
        .enumerate()
        .filter(|&(i, _)| !precomputed_multiples.contains_key(&i))
        .map(|(i, element)| (i, compute_multiples(element, default_window_width)))
        .collect();

    // Per term: its table of multiples and the window width that table corresponds to.
    let mut window_sizes = [0usize; N];
    let mut all_precomputed_multiples: Vec<&[G]> = Vec::with_capacity(N);
    for (i, window_size) in window_sizes.iter_mut().enumerate() {
        let multiples = match precomputed_multiples.get(&i) {
            Some(multiples) => {
                // A table of length 2^(w-1) belongs to window width w.
                *window_size = integer_utils::log2(multiples.len()) + 1;
                multiples.as_slice()
            }
            None => {
                *window_size = default_window_width;
                missing_precomputations[&i].as_slice()
            }
        };
        all_precomputed_multiples.push(multiples);
    }

    let scalar_bytes: Vec<Vec<u8>> = scalars.iter().map(|s| s.to_le_bytes()).collect();
    let scalar_size = scalar_bytes
        .iter()
        .map(Vec::len)
        .max()
        .expect("No scalars given.");

    // Sliding-window state, one entry per term.
    let mut is_in_window = [false; N];
    let mut index_in_window = [0usize; N];
    let mut precomputed_multiple_index = [0usize; N];

    // `None` until the first term has been added.
    let mut result: Option<G> = None;
    for bit in (0..scalar_size * 8).rev() {
        result = result.map(|sum| sum.double());
        for i in 0..N {
            if is_in_window[i] {
                index_in_window[i] += 1;
                if index_in_window[i] == window_sizes[i] {
                    // The pass has reached the lowest bit of this window: add its multiple now
                    // so the remaining doublings scale it correctly.
                    let multiple = &all_precomputed_multiples[i][precomputed_multiple_index[i]];
                    result = Some(accumulate(result, multiple));
                    is_in_window[i] = false;
                }
            } else if test_bit(&scalar_bytes[i], bit) {
                // Only open windows of width >= 2 that fit in the remaining bits `bit..=0`.
                // A width-1 window would never complete above, because completion is only
                // checked after the index is incremented past its starting value of 1.
                // Written as `bit + 1 >= w` so a zero width cannot underflow.
                if window_sizes[i] > 1 && bit + 1 >= window_sizes[i] {
                    is_in_window[i] = true;
                    index_in_window[i] = 1;
                    // The window's top bit is known to be set. The bits below it (presumably
                    // `get_bits_from_bytes` excludes `end`, matching the 2^(w-1)-entry table)
                    // select `(2^(w-1) + index) * g`.
                    precomputed_multiple_index[i] =
                        get_bits_from_bytes(&scalar_bytes[i], bit + 1 - window_sizes[i], bit);
                } else {
                    result = Some(accumulate(result, &elements[i]));
                }
            }
        }
    }
    result.unwrap_or(zero)
}
fn compute_multiples<G: Doubling + for<'a> Add<&'a G, Output = G> + Clone + Debug>(
base_element: &G,
window_size: usize,
) -> Vec<G> {
assert!(window_size > 0, "Window size must be strictly positive.");
let mut smallest_multiple = base_element.clone();
for _ in 1..window_size {
smallest_multiple = smallest_multiple.double();
}
successors(Some(smallest_multiple), |g| Some(g.clone() + base_element))
.take(1 << (window_size - 1))
.collect::<Vec<_>>()
}
#[cfg(test)]
mod tests {
    use ark_ff::{BigInteger, PrimeField};
    use ark_secp256r1::Fr;
    use rand::thread_rng;

    use crate::groups::ristretto255::{RistrettoPoint, RistrettoScalar};
    use crate::groups::secp256r1::{ProjectivePoint, Scalar};
    use crate::groups::GroupElement;
    use crate::groups::Scalar as ScalarTrait;
    use crate::serde_helpers::ToFromByteArray;

    use super::*;

    impl ToLittleEndianBytes for RistrettoScalar {
        fn to_le_bytes(&self) -> Vec<u8> {
            self.to_byte_array().to_vec()
        }
    }

    /// Checks `mul` against the reference `generator * scalar` on secp256r1 for one cache size.
    /// The multiplier is built once and reused for every scalar.
    fn check_secp256r1_mul<const CACHE_SIZE: usize>(scalars: &[Scalar]) {
        let multiplier = WindowedScalarMultiplier::<ProjectivePoint, Scalar, CACHE_SIZE, 4>::new(
            ProjectivePoint::generator(),
            ProjectivePoint::zero(),
        );
        for scalar in scalars {
            let expected = ProjectivePoint::generator() * *scalar;
            let actual = multiplier.mul(scalar);
            assert_eq!(expected, actual, "CACHE_SIZE = {}", CACHE_SIZE);
        }
    }

    #[test]
    fn test_scalar_multiplication_ristretto() {
        let multiplier = WindowedScalarMultiplier::<RistrettoPoint, RistrettoScalar, 16, 4>::new(
            RistrettoPoint::generator(),
            RistrettoPoint::zero(),
        );
        let scalars = [
            RistrettoScalar::from(0),
            RistrettoScalar::from(1),
            RistrettoScalar::from(2),
            RistrettoScalar::from(1234),
            RistrettoScalar::from(123456),
            RistrettoScalar::from(123456789),
            RistrettoScalar::from(0xffffffffffffffff),
            RistrettoScalar::group_order(),
            RistrettoScalar::group_order() - RistrettoScalar::from(1),
            RistrettoScalar::group_order() + RistrettoScalar::from(1),
        ];
        for scalar in scalars {
            let expected = RistrettoPoint::generator() * scalar;
            let actual = multiplier.mul(&scalar);
            assert_eq!(expected, actual);
        }
    }

    #[test]
    fn test_scalar_multiplication_secp256r1() {
        // (p - 1) / 2 doubled is p - 1, the largest canonical scalar.
        let mut modulus_minus_one = Fr::MODULUS_MINUS_ONE_DIV_TWO;
        modulus_minus_one.mul2();
        let scalars = [
            Scalar::from(0),
            Scalar::from(1),
            Scalar::from(2),
            Scalar::from(1234),
            Scalar::from(123456),
            Scalar::from(123456789),
            Scalar::from(0xffffffffffffffff),
            Scalar(Fr::from(modulus_minus_one)),
        ];
        check_secp256r1_mul::<2>(&scalars);
        check_secp256r1_mul::<16>(&scalars);
        check_secp256r1_mul::<32>(&scalars);
        check_secp256r1_mul::<64>(&scalars);
        check_secp256r1_mul::<512>(&scalars);
    }

    #[test]
    fn test_double_mul_ristretto() {
        let multiplier = WindowedScalarMultiplier::<RistrettoPoint, RistrettoScalar, 16, 5>::new(
            RistrettoPoint::generator(),
            RistrettoPoint::zero(),
        );
        let other_point = RistrettoPoint::generator() * RistrettoScalar::from(3);
        let a = RistrettoScalar::rand(&mut thread_rng());
        let b = RistrettoScalar::rand(&mut thread_rng());
        let expected = RistrettoPoint::generator() * a + other_point * b;
        let actual = multiplier.two_scalar_mul(&a, &other_point, &b);
        assert_eq!(expected, actual);
    }
}