#![cfg(not(target_arch = "bpf"))]
use {
curve25519_dalek::{ristretto::RistrettoPoint, scalar::Scalar, traits::Identity},
serde::{Deserialize, Serialize},
std::collections::HashMap,
};
/// 2^15; used as the progress-report interval while filling the precomputed table.
const TWO15: u32 = 1 << 15;
/// 2^14; the table's step multiplier and the number of `x_lo` candidates tried online.
const TWO14: u32 = 1 << 14;
/// 2^18; the number of `x_hi` entries in the precomputed table.
const TWO18: u32 = 1 << 18;
/// A discrete-log instance: recover `x` such that `target == x * generator`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DiscreteLog {
    /// Point the logarithm is taken with respect to.
    pub generator: RistrettoPoint,
    /// Point whose logarithm is sought.
    pub target: RistrettoPoint,
}
/// Lookup table consumed by `DiscreteLog::decode_u32_online`.
///
/// Each key is the 32-byte compressed encoding of a Ristretto point and each value
/// is the `x_hi` index that point corresponds to (see `decode_u32_precomputation`).
#[derive(Default, Serialize, Deserialize)]
pub struct DecodeU32Precomputation(HashMap<[u8; 32], u32>);
/// Builds the lookup table used by `DiscreteLog::decode_u32_online`.
///
/// The table maps the compressed encoding of `(2^14 * x_hi) * generator` to `x_hi`
/// for every `x_hi` in `0..2^18`, so it always holds 2^18 entries (fewer only if the
/// generator makes distinct multiples collide). Progress is printed to stdout in
/// eight steps while the table is filled.
fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodeU32Precomputation {
    // Exactly TWO18 insertions follow; reserve once instead of rehashing as the map grows.
    let mut hashmap = HashMap::with_capacity(TWO18 as usize);

    // Distance between consecutive table points: 2^14 * generator.
    let step = Scalar::from(TWO14) * generator;

    // Yields 0, 2^14 * generator, 2 * 2^14 * generator, ...
    let ristretto_iter = RistrettoIterator::new(RistrettoPoint::identity(), step);

    let mut steps_for_breakpoint = 0;
    for (elem, x_hi) in ristretto_iter.zip(0..TWO18) {
        let key = elem.compress().to_bytes();
        hashmap.insert(key, x_hi);

        // Report progress every 2^15 entries: eight reports over 2^18 entries.
        if x_hi % TWO15 == 0 {
            println!(" [{:?}/8] completed", steps_for_breakpoint);
            steps_for_breakpoint += 1;
        }
    }
    println!(" [8/8] completed");

    DecodeU32Precomputation(hashmap)
}
lazy_static::lazy_static! {
    /// Lookup table for the Ristretto basepoint, loaded from the embedded
    /// `decode_u32_precomputation_for_G.bincode` file on first access.
    ///
    /// If the embedded bytes do not deserialize, this silently falls back to an
    /// empty table; the `test_serialize_decode_u32_precomputation_for_G` test relies
    /// on that to detect a stale or missing file and regenerate it.
    pub static ref DECODE_U32_PRECOMPUTATION_FOR_G: DecodeU32Precomputation = {
        let serialized: &[u8] = include_bytes!("decode_u32_precomputation_for_G.bincode");
        bincode::deserialize::<DecodeU32Precomputation>(serialized).unwrap_or_default()
    };
}
impl DiscreteLog {
    /// Solves the discrete log of `target` with respect to `generator`, assuming the
    /// answer fits in a `u32`.
    ///
    /// This rebuilds the full 2^18-entry table (`decode_u32_precomputation`) on every
    /// call, which dominates the running time. Callers decoding many values for the
    /// same generator should build the table once and call `decode_u32_online`.
    pub(crate) fn decode_u32(self) -> Option<u32> {
        self.decode_u32_online(&decode_u32_precomputation(self.generator))
    }

    /// Solves the discrete log using a table produced by `decode_u32_precomputation`
    /// for `self.generator`.
    ///
    /// The unknown is written as `x = x_lo + 2^14 * x_hi` with `x_lo` in `0..2^14`.
    /// For each candidate `x_lo`, the point `target - x_lo * generator` is looked up in
    /// `hashmap`, which maps `(2^14 * x_hi) * generator` to `x_hi`.
    ///
    /// Returns `Some(x)` for the first candidate that hits the table. Returns `None`
    /// if no candidate hits, which happens when the logarithm is not below 2^32 or the
    /// table is empty.
    pub fn decode_u32_online(self, hashmap: &DecodeU32Precomputation) -> Option<u32> {
        // Yields target, target - generator, target - 2 * generator, ...
        let ristretto_iter = RistrettoIterator::new(self.target, -self.generator);

        // Stop at the first hit instead of scanning all 2^14 candidates. With a
        // non-identity generator in the prime-order Ristretto group, and a table built
        // for that generator, two hits would imply two distinct values below 2^32
        // with the same multiple. That is impossible, so the first hit is the only one.
        ristretto_iter.zip(0..TWO14).find_map(|(elem, x_lo)| {
            let key = elem.compress().to_bytes();
            // A single hash lookup replaces `contains_key` followed by indexing.
            // x_lo < 2^14 and x_hi < 2^18 for a freshly built table, so this fits in u32.
            hashmap.0.get(&key).map(|&x_hi| x_lo + TWO14 * x_hi)
        })
    }
}
/// Unbounded iterator over the progression `curr, curr + step, curr + 2 * step, ...`.
///
/// It never returns `None`; callers bound it (e.g. by zipping with a range).
struct RistrettoIterator {
    /// Point that the next call to `next` yields.
    pub curr: RistrettoPoint,
    /// Point added to `curr` after each yield.
    pub step: RistrettoPoint,
}
impl RistrettoIterator {
fn new(curr: RistrettoPoint, step: RistrettoPoint) -> Self {
RistrettoIterator { curr, step }
}
}
impl Iterator for RistrettoIterator {
    type Item = RistrettoPoint;

    /// Yields the current point and advances by `step`; always returns `Some`.
    fn next(&mut self) -> Option<RistrettoPoint> {
        let following = self.curr + self.step;
        Some(std::mem::replace(&mut self.curr, following))
    }
}
#[cfg(test)]
mod tests {
    use {
        super::*, curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT as G,
        std::time::Instant,
    };

    /// Checks that the embedded table matches a freshly computed one for `G`.
    ///
    /// If it does not, the test writes a fresh table to the source tree and fails,
    /// asking for a rebuild so that `include_bytes!` picks up the new file.
    #[test]
    // `G` follows the mathematical notation for the basepoint.
    #[allow(non_snake_case)]
    fn test_serialize_decode_u32_precomputation_for_G() {
        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
        if decode_u32_precomputation_for_G.0 != DECODE_U32_PRECOMPUTATION_FOR_G.0 {
            std::fs::write(
                "src/encryption/decode_u32_precomputation_for_G.bincode",
                bincode::serialize(&decode_u32_precomputation_for_G).unwrap(),
            )
            .unwrap();
            panic!("Rebuild and run this test again");
        }
    }

    /// Decodes a known amount and the u32 boundary values, reporting timings.
    #[test]
    fn test_decode_correctness() {
        let amount: u32 = 65545;
        let instance = DiscreteLog {
            generator: G,
            target: Scalar::from(amount) * G,
        };

        // std::time::Instant replaces the deprecated `time::precise_time_s` (time 0.1).
        let start_precomputation = Instant::now();
        let precomputed_hashmap = decode_u32_precomputation(G);
        let precomputation_secs = start_precomputation.elapsed().as_secs_f64();

        let start_online = Instant::now();
        let computed_amount = instance.decode_u32_online(&precomputed_hashmap).unwrap();
        let online_secs = start_online.elapsed().as_secs_f64();

        assert_eq!(amount, computed_amount);

        // Boundary values: the smallest and largest decodable logarithms. The
        // expensive table is reused, so these only cost online lookups.
        for &edge in &[0u32, u32::MAX] {
            let edge_instance = DiscreteLog {
                generator: G,
                target: Scalar::from(edge) * G,
            };
            assert_eq!(
                edge_instance.decode_u32_online(&precomputed_hashmap),
                Some(edge)
            );
        }

        println!(
            "16/16 Split precomputation: {:?} sec",
            precomputation_secs
        );
        println!("16/16 Split online computation: {:?} sec", online_secs);
    }
}