use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::traits::arithmetic_circuit::ArithmeticCircuit,
expressions::{expr::EvalFailure, field_expr::FieldExpr},
global_value::value::FieldValue,
},
traits::{Equal, Invert, Pow, Reveal},
utils::{number::Number, used_field::UsedField},
};
use std::ops::Mul;
/// Arithmetic circuit raising each input to a fixed power `exponent`.
///
/// Besides the generic square-and-multiply construction, `run` has dedicated paths for
/// the exponents `alpha`, `alpha_inverse`, `(p - 1) / 2`, `(p - 3) / 4` (when
/// `p ≡ 3 mod 4`) and `(p - 5) / 8` (when `p ≡ 5 mod 8`). These paths blind the secret
/// input with a random value, reveal the blinded product and exponentiate it in the
/// clear.
#[derive(Clone, Debug)]
pub struct PowCircuit {
    /// The power every input is raised to. A value of `0` yields the constant `1`.
    pub exponent: Number,
    /// Caller's promise that the (secret) input is never zero.
    ///
    /// When `true`, `run` skips the zero-masking step on the blinded paths and allows
    /// the cheaper `alpha` path. If the input is in fact zero, the revealed blinded value
    /// on those paths is itself zero.
    pub is_expected_non_zero: bool,
}
/// Raises `a` to the positive integer power `b` using right-to-left square-and-multiply.
///
/// Performs `b.bits() - 1` squarings and one multiplication for every set bit of `b`
/// after the lowest one. No squaring is computed whose result goes unused, and there is
/// no leading multiplication by one. This matters because `T` may be a circuit value, where
/// every `Mul` may add work. The routine is generic so the same code serves plain
/// numbers, bounds and circuit values.
///
/// # Panics
/// Panics if `b <= 0`.
fn smart_pow<T: Mul<T, Output = T> + Clone + From<i32>>(a: T, b: &Number) -> T {
    assert!(*b > 0);
    // `square` holds `a^(2^i)` during iteration `i`; `acc` is the product of the squares
    // selected so far (`None` until the first set bit is seen).
    let mut square = a;
    let mut acc: Option<T> = None;
    // `enumerate` yields a `usize` index, which is what `Number::bit` is called with.
    for (i, _) in (0..b.bits()).enumerate() {
        if i > 0 {
            // Square lazily, only when moving to the next bit. The top bit of a positive
            // `b` is always set, so the last square is always consumed.
            square = square.clone() * square;
        }
        if b.bit(i) {
            acc = Some(match acc {
                Some(prev) => prev * square.clone(),
                None => square.clone(),
            });
        }
    }
    // Unreachable for `b > 0` (at least one bit is set); kept as the neutral element.
    acc.unwrap_or_else(|| T::from(1))
}
impl PowCircuit {
    /// Creates a circuit computing `x^exponent`.
    ///
    /// `is_expected_non_zero` declares that inputs are never zero, which lets `run` skip
    /// the zero-masking step on its blinded exponentiation paths.
    // Only constructed from some build configurations (e.g. the test generator).
    #[allow(unused)]
    pub fn new(exponent: Number, is_expected_non_zero: bool) -> Self {
        Self {
            is_expected_non_zero,
            exponent,
        }
    }
}
impl<F: UsedField> ArithmeticCircuit<F> for PowCircuit {
fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
Ok(x.into_iter()
.map(|val| val.pow(&self.exponent, true))
.collect())
}
fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
bounds
.into_iter()
.map(|bound| {
if self.exponent == 0 {
FieldBounds::new(F::ONE, F::ONE)
} else {
smart_pow(bound, &self.exponent)
}
})
.collect()
}
fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
where
F: ActuallyUsedField,
{
assert!(vals.len() == 1);
let pow = if self.exponent == 0 {
FieldValue::<F>::from(1)
} else if vals[0].is_plaintext() {
FieldValue::<F>::new(FieldExpr::Pow(vals[0], self.exponent.clone(), false))
} else if self.exponent == F::get_alpha() && self.is_expected_non_zero {
let lambda = FieldValue::<F>::random();
let lambda_inv = lambda.invert(true);
let lambda_pow_neg_alpha = smart_pow(lambda_inv, &F::get_alpha());
let x = vals[0];
let z = (x * lambda).reveal();
let z_pow_alpha = FieldValue::<F>::new(FieldExpr::Pow(z, F::get_alpha(), true));
z_pow_alpha * lambda_pow_neg_alpha
} else if self.exponent == F::get_alpha_inverse() {
let lambda = FieldValue::<F>::random();
let lambda_inv = lambda.invert(true);
let lambda_pow_neg_alpha = smart_pow(lambda_inv, &F::get_alpha());
let x = vals[0];
let is_zero_mask = if !self.is_expected_non_zero {
FieldValue::from(x.eq(FieldValue::from(0)))
} else {
FieldValue::from(0)
};
let x = x + is_zero_mask;
let z = (x * lambda_pow_neg_alpha).reveal();
let z_pow_alpha_inverse =
FieldValue::<F>::new(FieldExpr::Pow(z, F::get_alpha_inverse(), true));
z_pow_alpha_inverse * lambda - is_zero_mask
} else if self.exponent == (F::modulus() - 1) / 2 {
let lambda = FieldValue::<F>::random();
let lambda_pow_exponent = smart_pow(lambda, &self.exponent);
let x = vals[0];
let is_zero_mask = if !self.is_expected_non_zero {
FieldValue::from(x.eq(FieldValue::from(0)))
} else {
FieldValue::from(0)
};
let x = x + is_zero_mask;
let z = (x * lambda).reveal();
let z_pow_exponent =
FieldValue::<F>::new(FieldExpr::Pow(z, self.exponent.clone(), true));
z_pow_exponent * lambda_pow_exponent - is_zero_mask
} else if F::modulus() % 4 == 3 && self.exponent == (F::modulus() - 3) / 4 {
let lambda = FieldValue::<F>::random();
let lambda_pow = smart_pow(lambda, &((3 * F::modulus() - 1) / 4));
let x = vals[0];
let is_zero_mask = if !self.is_expected_non_zero {
FieldValue::from(x.eq(FieldValue::from(0)))
} else {
FieldValue::from(0)
};
let x = x + is_zero_mask;
let z = (x * lambda).reveal();
let z_pow_exponent =
FieldValue::<F>::new(FieldExpr::Pow(z, self.exponent.clone(), true));
z_pow_exponent * lambda_pow - is_zero_mask
} else if F::modulus() % 8 == 5 && self.exponent == (F::modulus() - 5) / 8 {
let lambda = FieldValue::<F>::random();
let lambda_pow = smart_pow(lambda, &((7 * F::modulus() - 3) / 8));
let x = vals[0];
let is_zero_mask = if !self.is_expected_non_zero {
FieldValue::from(x.eq(FieldValue::from(0)))
} else {
FieldValue::from(0)
};
let x = x + is_zero_mask;
let z = (x * lambda).reveal();
let z_pow_exponent =
FieldValue::<F>::new(FieldExpr::Pow(z, self.exponent.clone(), true));
z_pow_exponent * lambda_pow - is_zero_mask
} else {
smart_pow(vals[0], &self.exponent)
};
vec![pow]
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        ArcisField,
    };
    use num_traits::ToPrimitive;
    use rand::Rng;
    use std::marker::PhantomData;

    /// Checks `smart_pow` against naive repeated multiplication for a spread of small
    /// exponents and bases. The exponents cover powers of two, all-ones bit patterns and
    /// sparse patterns; the bases include `0` and `1`.
    #[test]
    fn test_smart_pow() {
        for exponent in 1usize..=12 {
            let exponent_number = Number::from(exponent);
            for base in [
                Number::from(-1),
                5.into(),
                (-3).into(),
                0.into(),
                1.into(),
                2.into(),
            ] {
                let tested = smart_pow(base.clone(), &exponent_number);
                let expected = (0..exponent).fold(Number::from(1), |acc, _| acc * &base);
                assert_eq!(tested, expected);
            }
        }
    }

    impl<F: ActuallyUsedField> TestedArithmeticCircuit<F> for PowCircuit {
        /// Generates a random circuit, targeting each specialised path of `run`.
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            let is_expected_non_zero = rng.gen_bool(0.5);
            if rng.gen_bool(0.25) {
                return Self::new(0.into(), is_expected_non_zero);
            }
            let r = rng.next_u64() % 6;
            match r {
                0 => Self::new((rng.next_u64() % 256).into(), is_expected_non_zero),
                1 => Self::new(F::get_alpha_inverse(), is_expected_non_zero),
                2 => Self::new((F::modulus() - 1) / 2, is_expected_non_zero),
                3 if F::modulus() % 4 == 3 => {
                    Self::new((F::modulus() - 3) / 4, is_expected_non_zero)
                }
                3 if F::modulus() % 8 == 5 => {
                    Self::new((F::modulus() - 5) / 8, is_expected_non_zero)
                }
                // Exercise the `alpha` path of `run` directly. Otherwise it is only
                // reached when a random exponent below 256 happens to equal `alpha`.
                4 => Self::new(F::get_alpha(), is_expected_non_zero),
                _ => Self::new(rng.next_u64().into(), is_expected_non_zero),
            }
        }
        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }
        /// Cross-checks outputs by naive multiplication when the exponent fits in a `u8`.
        fn extra_checks(&self, inputs: Vec<F>, outputs: Vec<F>) {
            assert_eq!(inputs.len(), outputs.len());
            inputs.into_iter().zip(outputs).for_each(|(input, output)| {
                if self.exponent == 0 {
                    assert_eq!(output, F::ONE);
                } else if let Some(n) = self.exponent.to_u8() {
                    let mut res = input;
                    for _ in 1..n {
                        res *= input
                    }
                    assert_eq!(output, res);
                }
            })
        }
    }

    #[test]
    fn tested() {
        PowCircuit::test_with_marker(64, 4, PhantomData::<ArcisField>)
    }
}