use curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::Identity};
use rand::rng;
use crate::{multiexp_num_buckets, multiexp_num_windows};
pub(crate) struct BinData {
pub sorted_scalar_ids: Vec<u32>,
pub bin_ids: Vec<u32>,
pub bin_counts: Vec<u32>,
pub bin_start_idx: Vec<u32>,
}
const LIMBS_PER_SCALAR: u32 = 8;
fn get_scalar_window(
scalar: &Scalar,
window_bits: u32, window_id: u32,
) -> u32 {
const BITS_PER_LIMB: u32 = 8 * std::mem::size_of::<u32>() as u32;
let window_start_idx = window_bits * window_id;
let limb_id_1 = window_start_idx / BITS_PER_LIMB;
let limbs = bytemuck::cast_slice::<_, u32>(scalar.as_bytes());
let limb_1 = limbs[limb_id_1 as usize];
let lo_mask = if window_bits < 32 {
(0x1 << window_bits) - 1
} else {
0xFFFFFFFF
};
let mut window = (limb_1 >> (window_start_idx % BITS_PER_LIMB)) & lo_mask;
let limb_boundary: u32 = (limb_id_1 + 1) * BITS_PER_LIMB;
if window_bits + window_start_idx > limb_boundary && limb_id_1 < LIMBS_PER_SCALAR - 1 {
let limb_id_2 = limb_id_1 + 1;
let limb_2 = limbs[limb_id_2 as usize];
let bits_remaining = window_start_idx + window_bits - limb_boundary;
let hi_mask = (0x1 << bits_remaining) - 1;
window |= (limb_2 & hi_mask) << (window_bits - bits_remaining);
}
window
}
fn rle(data: &[u32]) -> (Vec<u32>, Vec<u32>) {
if data.is_empty() {
return (vec![], vec![]);
}
let mut prev = data[0];
let mut count = 1;
let mut vals = vec![prev];
let mut runs = vec![];
for val in data.iter().skip(1).cloned() {
if val != prev {
vals.push(val);
runs.push(count);
prev = val;
count = 1;
} else {
count += 1;
}
}
runs.push(count);
assert_eq!(vals.len(), runs.len());
(vals, runs)
}
fn prefix_sum(x: &[u32]) -> Vec<u32> {
if x.is_empty() {
return vec![];
}
let mut sum = vec![0];
for (i, val) in x[0..(x.len() - 1)].iter().enumerate() {
sum.push(sum[i] + val);
}
sum
}
pub(crate) fn construct_bin_data(scalars: &[Scalar], window_bits: usize) -> Vec<BinData> {
const SCALAR_BIT_LEN: usize = 8 * std::mem::size_of::<Scalar>();
let num_windows = if SCALAR_BIT_LEN % window_bits == 0 {
SCALAR_BIT_LEN / window_bits
} else {
SCALAR_BIT_LEN / window_bits + 1
};
let mut bin_data = vec![];
for i in 0..num_windows {
let mut bins = scalars
.iter()
.enumerate()
.map(|x| (x.0, get_scalar_window(x.1, window_bits as u32, i as u32)))
.collect::<Vec<_>>();
bins.sort_by(|x, y| x.1.cmp(&y.1));
let sorted_scalar_ids = bins.iter().map(|x| x.0 as u32).collect::<Vec<_>>();
let sorted_bin_ids = bins.iter().map(|x| x.1).collect::<Vec<_>>();
let rle_bins = rle(&sorted_bin_ids);
let rle_sum = prefix_sum(&rle_bins.1);
let mut bin_offset_lengths = rle_bins
.1
.iter()
.zip(rle_sum.iter())
.zip(&rle_bins.0)
.collect::<Vec<_>>();
bin_offset_lengths.sort_by(|x, y| x.0 .0.cmp(y.0 .0));
let sorted_bin_counts = bin_offset_lengths
.iter()
.map(|x| *x.0 .0)
.collect::<Vec<_>>();
let sorted_bin_offsets = bin_offset_lengths
.iter()
.map(|x| *x.0 .1)
.collect::<Vec<_>>();
let sorted_bin_ids = bin_offset_lengths.iter().map(|x| *x.1).collect::<Vec<_>>();
bin_data.push(BinData {
sorted_scalar_ids,
bin_start_idx: sorted_bin_offsets,
bin_ids: sorted_bin_ids,
bin_counts: sorted_bin_counts,
});
}
bin_data
}
pub(crate) fn compute_bucket_points(
scalars: &[Scalar],
points: &[RistrettoPoint],
window_size: usize,
) -> Vec<Vec<RistrettoPoint>> {
assert_eq!(scalars.len(), points.len());
let num_buckets = multiexp_num_buckets(window_size);
let num_windows = multiexp_num_windows(window_size);
let mut window_bucket_points = vec![];
for window_id in 0..num_windows {
let mut written = vec![false; num_buckets];
let mut bucket_points = vec![RistrettoPoint::identity(); num_buckets];
for (i, s) in scalars.iter().enumerate() {
let cur_window = get_scalar_window(s, window_size as u32, window_id as u32);
assert!((cur_window as usize) < num_buckets);
if cur_window != 0 {
if !written[cur_window as usize] {
bucket_points[cur_window as usize] = points[i];
written[cur_window as usize] = true;
} else {
bucket_points[cur_window as usize] += points[i];
}
}
}
window_bucket_points.push(bucket_points);
}
assert_eq!(window_bucket_points.len(), num_windows);
window_bucket_points
}
#[test]
fn test_impl_get_scalar_window() {
let expected = Scalar::random(&mut rng());
for window_size in 10..33 {
let mut windows = vec![];
const SCALAR_BITS: usize = 8 * std::mem::size_of::<Scalar>();
let num_windows = if SCALAR_BITS % window_size == 0 {
SCALAR_BITS / window_size
} else {
SCALAR_BITS / window_size + 1
};
for window_id in 0..num_windows as u32 {
windows.push(get_scalar_window(&expected, window_size as u32, window_id));
}
let mut actual = Scalar::zero();
let mut radix = Scalar::one();
for window in windows.iter() {
assert!((*window as u64) < (0x1u64 << window_size as u64));
actual += Scalar::from(*window) * radix;
radix *= Scalar::from(0x1u64 << window_size as u64);
}
assert_eq!(actual, expected);
}
}
pub(crate) struct PrefixSumBlockRistretto {
pub block_sums: Vec<RistrettoPoint>,
pub block_totals: Vec<RistrettoPoint>,
}
pub(crate) fn prefix_sum_blocks_ristretto(
points: &[RistrettoPoint],
block_size: usize,
) -> PrefixSumBlockRistretto {
let block_totals = points
.chunks(block_size)
.map(|c| c.iter().fold(RistrettoPoint::identity(), |s, x| s + x))
.collect::<Vec<_>>();
let block_sums = points
.chunks(block_size)
.map(|c| {
let mut sum = RistrettoPoint::identity();
let mut t;
let mut out = c.to_owned();
for val in out.iter_mut() {
t = *val;
*val = sum;
sum += t;
}
out
})
.collect::<Vec<_>>()
.concat();
let expected_total_len = if points.len() % block_size == 0 {
points.len() / block_size
} else {
points.len() / block_size + 1
};
assert_eq!(block_totals.len(), expected_total_len);
assert_eq!(block_sums.len(), points.len());
PrefixSumBlockRistretto {
block_sums,
block_totals,
}
}
#[test]
fn test_rle_impl_works() {
let (vals, runs) = rle(&[1, 1, 1, 2, 2, 3, 4, 4, 4, 4, 5, 5, 7, 7, 7]);
assert_eq!(vals, vec![1, 2, 3, 4, 5, 7]);
assert_eq!(runs, vec![3, 2, 1, 4, 2, 3]);
}
#[test]
fn test_prefix_sum_works() {
let sum = prefix_sum(&[1, 3, 5, 7, 8, 11]);
assert_eq!(sum, vec![0, 1, 4, 9, 16, 24]);
}