use mohan::dalek::scalar::Scalar;
use crate::proofs::bulletproofs::inner_product_proof::inner_product;
/// A degree-1 polynomial whose coefficients are vectors of scalars,
/// representing `self.0 + self.1 · x` entrywise.
///
/// `eval` indexes both vectors over `0..self.0.len()`, so they are expected to
/// have the same length. Every coefficient is handed to `mohan::zeroize_hack`
/// when the value is dropped.
pub struct VecPoly1(
    /// Constant-term coefficients.
    pub Vec<Scalar>,
    /// Coefficients of `x`.
    pub Vec<Scalar>,
);
/// A degree-3 polynomial whose coefficients are vectors of scalars,
/// representing `self.0 + self.1·x + self.2·x² + self.3·x³` entrywise.
///
/// `eval` indexes all four vectors in lockstep over `0..self.0.len()`, so they
/// are expected to share a length.
pub struct VecPoly3(
    /// Constant-term coefficients.
    pub Vec<Scalar>,
    /// Coefficients of `x`.
    pub Vec<Scalar>,
    /// Coefficients of `x²`.
    pub Vec<Scalar>,
    /// Coefficients of `x³`.
    pub Vec<Scalar>,
);
/// A scalar polynomial of degree at most two, `self.0 + self.1·x + self.2·x²`.
pub struct Poly2(
    /// Constant term.
    pub Scalar,
    /// Coefficient of `x`.
    pub Scalar,
    /// Coefficient of `x²`.
    pub Scalar,
);
/// A scalar polynomial `t1·x + t2·x² + … + t6·x⁶`.
///
/// There is deliberately no constant-term field; `eval` treats it as zero.
pub struct Poly6 {
    /// Coefficient of `x`.
    pub t1: Scalar,
    /// Coefficient of `x²`.
    pub t2: Scalar,
    /// Coefficient of `x³`.
    pub t3: Scalar,
    /// Coefficient of `x⁴`.
    pub t4: Scalar,
    /// Coefficient of `x⁵`.
    pub t5: Scalar,
    /// Coefficient of `x⁶`.
    pub t6: Scalar,
}
/// Never-ending iterator over the powers `1, x, x², x³, …` of a scalar `x`.
///
/// Build one with [`exp_iter`].
pub struct ScalarExp {
    /// The base being raised to successive powers.
    x: Scalar,
    /// The power the next call to `next` will hand out.
    next_exp_x: Scalar,
}

impl Iterator for ScalarExp {
    type Item = Scalar;

    fn next(&mut self) -> Option<Scalar> {
        // Yield the current power, then step forward by one more factor of `x`.
        let current = self.next_exp_x;
        self.next_exp_x = current * self.x;
        Some(current)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The sequence is unbounded, so report the maximal lower bound and no upper bound.
        (usize::MAX, None)
    }
}

/// Returns an iterator yielding `x^0, x^1, x^2, …`, beginning with one.
pub fn exp_iter(x: Scalar) -> ScalarExp {
    ScalarExp {
        x,
        next_exp_x: Scalar::one(),
    }
}
pub fn add_vec(a: &[Scalar], b: &[Scalar]) -> Vec<Scalar> {
if a.len() != b.len() {
println!("lengths of vectors don't match for vector addition");
}
let mut out = vec![Scalar::zero(); b.len()];
for i in 0..a.len() {
out[i] = a[i] + b[i];
}
out
}
impl VecPoly1 {
    /// Returns the zero polynomial, with both coefficient vectors holding `n` zeros.
    pub fn zero(n: usize) -> Self {
        let zeros = || vec![Scalar::zero(); n];
        VecPoly1(zeros(), zeros())
    }

    /// Computes the inner product `<self, rhs>` as a degree-2 scalar polynomial.
    ///
    /// The result `Poly2(t0, t1, t2)` satisfies
    /// `t0 + t1·x + t2·x² = <self.eval(x), rhs.eval(x)>`.
    pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 {
        let t0 = inner_product(&self.0, &rhs.0);
        let t2 = inner_product(&self.1, &rhs.1);

        // Karatsuba step: <l0 + l1, r0 + r1> = t0 + (<l0, r1> + <l1, r0>) + t2.
        // Subtracting t0 and t2 leaves the linear coefficient at the cost of one
        // inner product instead of two.
        let lhs_sum = add_vec(&self.0, &self.1);
        let rhs_sum = add_vec(&rhs.0, &rhs.1);
        let t1 = inner_product(&lhs_sum, &rhs_sum) - t0 - t2;

        Poly2(t0, t1, t2)
    }

    /// Evaluates the polynomial at `x`, producing `self.0[i] + self.1[i]·x` for each
    /// index of `self.0`.
    pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
        (0..self.0.len())
            .map(|i| self.0[i] + self.1[i] * x)
            .collect()
    }
}
impl VecPoly3 {
    /// Returns the zero polynomial: four coefficient vectors of `n` zeros each.
    pub fn zero(n: usize) -> Self {
        VecPoly3(
            vec![Scalar::zero(); n],
            vec![Scalar::zero(); n],
            vec![Scalar::zero(); n],
            vec![Scalar::zero(); n],
        )
    }

    /// Computes the inner product of two cubic vector polynomials and returns
    /// the coefficients of `x^1 … x^6`.
    ///
    /// The result is only correct when `lhs.0` (the constant term of `lhs`) and
    /// `rhs.2` (the `x²` term of `rhs`) are entirely zero:
    /// - every cross product involving those two vectors is skipped;
    /// - the constant term `<lhs.0, rhs.0>` is dropped, because `Poly6` has no slot for it.
    ///
    /// Violating the precondition would silently produce a wrong polynomial, so
    /// it is checked in debug builds.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `lhs.0` or `rhs.2` contains a nonzero entry.
    pub fn special_inner_product(lhs: &Self, rhs: &Self) -> Poly6 {
        debug_assert!(
            lhs.0.iter().all(|c| *c == Scalar::zero()),
            "special_inner_product: lhs.0 must be all zeros"
        );
        debug_assert!(
            rhs.2.iter().all(|c| *c == Scalar::zero()),
            "special_inner_product: rhs.2 must be all zeros"
        );

        // With lhs.0 = 0 and rhs.2 = 0, the coefficient of x^k is the sum of
        // <lhs.i, rhs.j> over the remaining pairs with i + j = k.
        let t1 = inner_product(&lhs.1, &rhs.0);
        let t2 = inner_product(&lhs.1, &rhs.1) + inner_product(&lhs.2, &rhs.0);
        let t3 = inner_product(&lhs.2, &rhs.1) + inner_product(&lhs.3, &rhs.0);
        let t4 = inner_product(&lhs.1, &rhs.3) + inner_product(&lhs.3, &rhs.1);
        let t5 = inner_product(&lhs.2, &rhs.3);
        let t6 = inner_product(&lhs.3, &rhs.3);
        Poly6 {
            t1,
            t2,
            t3,
            t4,
            t5,
            t6,
        }
    }

    /// Evaluates `self.0 + self.1·x + self.2·x² + self.3·x³` entrywise using
    /// Horner's rule.
    ///
    /// The output has `self.0.len()` entries. Indexing panics if any other
    /// coefficient vector is shorter.
    pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
        let n = self.0.len();
        let mut out = vec![Scalar::zero(); n];
        for i in 0..n {
            out[i] = self.0[i] + x * (self.1[i] + x * (self.2[i] + x * self.3[i]));
        }
        out
    }
}
impl Poly2 {
    /// Evaluates `self.0 + self.1·x + self.2·x²` at `x` using Horner's rule.
    pub fn eval(&self, x: Scalar) -> Scalar {
        let upper = self.1 + x * self.2;
        self.0 + x * upper
    }
}
impl Poly6 {
    /// Evaluates `t1·x + t2·x² + … + t6·x⁶` at `x`.
    pub fn eval(&self, x: Scalar) -> Scalar {
        // Horner's rule from the highest coefficient down. The final multiplication
        // by `x` accounts for the absent constant term.
        let mut acc = self.t6;
        acc = self.t5 + x * acc;
        acc = self.t4 + x * acc;
        acc = self.t3 + x * acc;
        acc = self.t2 + x * acc;
        acc = self.t1 + x * acc;
        x * acc
    }
}
impl Drop for VecPoly1 {
    /// Passes every coefficient of both vectors, constant term first, to
    /// `mohan::zeroize_hack` before the vectors are freed.
    fn drop(&mut self) {
        self.0
            .iter_mut()
            .chain(self.1.iter_mut())
            .for_each(|coeff| mohan::zeroize_hack(coeff));
    }
}
impl Drop for Poly2 {
fn drop(&mut self) {
mohan::zeroize_hack(&mut self.0);
mohan::zeroize_hack(&mut self.1);
mohan::zeroize_hack(&mut self.2);
}
}
impl Drop for VecPoly3 {
fn drop(&mut self) {
for e in self.0.iter_mut() {
mohan::zeroize_hack(e);
}
for e in self.1.iter_mut() {
mohan::zeroize_hack(e);
}
for e in self.2.iter_mut() {
mohan::zeroize_hack(e);
}
for e in self.3.iter_mut() {
mohan::zeroize_hack(e);
}
}
}
impl Drop for Poly6 {
fn drop(&mut self) {
mohan::zeroize_hack(&mut self.t1);
mohan::zeroize_hack(&mut self.t2);
mohan::zeroize_hack(&mut self.t3);
mohan::zeroize_hack(&mut self.t4);
mohan::zeroize_hack(&mut self.t5);
mohan::zeroize_hack(&mut self.t6);
}
}
/// Raises `x` to the power `n` by right-to-left square-and-multiply.
///
/// The multiplications performed depend on the bits of `n` (note the branch on
/// each bit), so the running time leaks `n`. Only use it with a public exponent.
pub fn scalar_exp_vartime(x: &Scalar, n: u64) -> Scalar {
    let mut acc = Scalar::one();
    let mut square = *x;
    let mut remaining = n;
    while remaining != 0 {
        // `square` holds x^(2^k) for the bit currently at the bottom of `remaining`.
        if remaining & 1 == 1 {
            acc = acc * square;
        }
        square = square * square;
        remaining >>= 1;
    }
    acc
}
/// Computes `1 + x + x² + … + x^(n-1)` (zero when `n == 0`).
///
/// When `n` is a power of two this takes `O(log n)` multiplications by
/// repeated doubling. Every other `n` falls back to summing the `n` powers one
/// by one.
pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar {
    // The doubling identity below only applies to powers of two. Zero is not a
    // power of two, so it is handled here as well.
    if !n.is_power_of_two() {
        return sum_of_powers_slow(x, n);
    }
    if n == 1 {
        return Scalar::one();
    }

    // Invariant: `sum` = 1 + x + … + x^(k-1) and `power` = x^(k/2).
    // Doubling uses  S(2k) = S(k) + x^k · S(k).
    let mut sum = Scalar::one() + x;
    let mut power = *x;
    let mut k = 2;
    while k < n {
        power = power * power;
        sum = sum + power * sum;
        k *= 2;
    }
    sum
}
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).sum()
}
/// Copies the first 32 bytes of `data` into a fixed-size array.
///
/// Any bytes beyond the 32nd are ignored.
///
/// # Panics
///
/// Panics if `data` is shorter than 32 bytes.
pub fn read32(data: &[u8]) -> [u8; 32] {
    let head = &data[..32];
    let mut out = [0u8; 32];
    out.copy_from_slice(head);
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exp_2_is_powers_of_2() {
        let exp_2: Vec<_> = exp_iter(Scalar::from(2u64)).take(4).collect();
        assert_eq!(exp_2[0], Scalar::from(1u64));
        assert_eq!(exp_2[1], Scalar::from(2u64));
        assert_eq!(exp_2[2], Scalar::from(4u64));
        assert_eq!(exp_2[3], Scalar::from(8u64));
    }

    #[test]
    fn test_inner_product() {
        let a = vec![
            Scalar::from(1u64),
            Scalar::from(2u64),
            Scalar::from(3u64),
            Scalar::from(4u64),
        ];
        let b = vec![
            Scalar::from(2u64),
            Scalar::from(3u64),
            Scalar::from(4u64),
            Scalar::from(5u64),
        ];
        assert_eq!(Scalar::from(40u64), inner_product(&a, &b));
    }

    /// Reference exponentiation: multiplies `x` into one, `n` times.
    fn scalar_exp_vartime_slow(x: &Scalar, n: u64) -> Scalar {
        let mut result = Scalar::one();
        for _ in 0..n {
            result = result * x;
        }
        result
    }

    #[test]
    fn test_scalar_exp() {
        let x = Scalar::from_bits(
            *b"\x84\xfc\xbcOx\x12\xa0\x06\xd7\x91\xd9z:'\xdd\x1e!CE\xf7\xb1\xb9Vz\x810sD\x96\x85\xb5\x07",
        );
        assert_eq!(scalar_exp_vartime(&x, 0), Scalar::one());
        assert_eq!(scalar_exp_vartime(&x, 1), x);
        assert_eq!(scalar_exp_vartime(&x, 2), x * x);
        assert_eq!(scalar_exp_vartime(&x, 3), x * x * x);
        assert_eq!(scalar_exp_vartime(&x, 4), x * x * x * x);
        assert_eq!(scalar_exp_vartime(&x, 5), x * x * x * x * x);
        assert_eq!(scalar_exp_vartime(&x, 64), scalar_exp_vartime_slow(&x, 64));
        assert_eq!(
            scalar_exp_vartime(&x, 0b11001010),
            scalar_exp_vartime_slow(&x, 0b11001010)
        );
    }

    #[test]
    fn test_sum_of_powers() {
        let x = Scalar::from(10u64);
        // Powers of two take the doubling fast path. Every other n (including 0)
        // takes the slow fallback, so cover both kinds.
        for &n in &[0usize, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 31, 32, 64, 100] {
            assert_eq!(sum_of_powers_slow(&x, n), sum_of_powers(&x, n), "n = {}", n);
        }
        // Anchor both paths to closed-form values (base-10 repunits).
        assert_eq!(sum_of_powers(&x, 3), Scalar::from(111u64));
        assert_eq!(sum_of_powers(&x, 4), Scalar::from(1111u64));
        assert_eq!(sum_of_powers(&x, 8), Scalar::from(11_111_111u64));
    }

    #[test]
    fn test_sum_of_powers_slow() {
        let x = Scalar::from(10u64);
        assert_eq!(sum_of_powers_slow(&x, 0), Scalar::zero());
        assert_eq!(sum_of_powers_slow(&x, 1), Scalar::one());
        assert_eq!(sum_of_powers_slow(&x, 2), Scalar::from(11u64));
        assert_eq!(sum_of_powers_slow(&x, 3), Scalar::from(111u64));
        assert_eq!(sum_of_powers_slow(&x, 4), Scalar::from(1111u64));
        assert_eq!(sum_of_powers_slow(&x, 5), Scalar::from(11111u64));
        assert_eq!(sum_of_powers_slow(&x, 6), Scalar::from(111111u64));
    }

    #[test]
    fn vec_poly1_inner_product_commutes_with_eval() {
        let s = |v: u64| Scalar::from(v);
        let l = VecPoly1(vec![s(1), s(2), s(3)], vec![s(4), s(5), s(6)]);
        let r = VecPoly1(vec![s(7), s(8), s(9)], vec![s(10), s(11), s(12)]);
        let x = s(3);
        let t = l.inner_product(&r);
        assert_eq!(t.eval(x), inner_product(&l.eval(x), &r.eval(x)));
    }

    #[test]
    fn vec_poly3_special_inner_product_commutes_with_eval() {
        let s = |v: u64| Scalar::from(v);
        // special_inner_product is only defined for lhs.0 == 0 and rhs.2 == 0.
        let l = VecPoly3(
            vec![Scalar::zero(); 2],
            vec![s(1), s(2)],
            vec![s(3), s(4)],
            vec![s(5), s(6)],
        );
        let r = VecPoly3(
            vec![s(7), s(8)],
            vec![s(9), s(10)],
            vec![Scalar::zero(); 2],
            vec![s(11), s(12)],
        );
        let x = s(5);
        let t = VecPoly3::special_inner_product(&l, &r);
        assert_eq!(t.eval(x), inner_product(&l.eval(x), &r.eval(x)));
    }
}