use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::{
boolean::boolean_value::BooleanValue,
traits::arithmetic_circuit::ArithmeticCircuit,
},
expressions::expr::EvalFailure,
global_value::value::FieldValue,
},
traits::{Equal, GetBit, Pow, Select, WithBooleanBounds},
utils::number::Number,
};
use std::ops::{Add, Mul, Sub};
/// Element-wise square-root circuit over a prime field.
///
/// This is a stateless unit type: all behavior lives in its [`ArithmeticCircuit`]
/// implementation, which maps every input to a square root (or to zero when the input is
/// not a square) using [`sqrt`].
#[derive(Debug, Clone)]
pub struct SqrtCircuit;
/// Computes a square root of `a` in the prime field `F`.
///
/// The code is generic over the value type `T` and its boolean type `B`, so one
/// implementation serves two uses:
/// * plain evaluation, where `T = F` and `B = bool`;
/// * circuit values, where `T = FieldValue<F>` and `B = BooleanValue`.
///
/// `SqrtCircuit`'s `eval`/`run` in this file instantiate it in exactly these two ways.
///
/// Returns `(is_square, root)`:
/// * `is_square` — whether `a` is a square in `F`. Zero counts as a square: `sqrt(0)`
///   yields `(true, 0)`.
/// * `root` — a value with `root * root == a` when `is_square` holds, and zero otherwise.
///   Which of the two roots (`r` or `-r`) is returned is not normalized.
///
/// `is_expected_non_zero`:
/// * It is forwarded to `Pow::pow` in the `p ≡ 3 (mod 4)` branch. It is presumably a hint
///   that `a != 0`; TODO confirm against the `Pow` trait.
/// * In the `p ≡ 5 (mod 8)` branch, passing `false` masks a zero input to 1 before
///   exponentiating and subtracts the mask again at the end.
///
/// # Panics
///
/// Panics if the field modulus `p` satisfies `p % 8 == 1`. That case needs a general
/// algorithm (e.g. Tonelli–Shanks), which is not implemented here.
pub fn sqrt<
F: ActuallyUsedField,
B: Select<T, T, T> + Copy,
T: GetBit<Output = B>
+ Equal<T, Output = B>
+ From<B>
+ Copy
+ Add<T, Output = T>
+ Sub<T, Output = T>
+ Mul<T, Output = T>
+ Pow
+ WithBooleanBounds
+ From<i32>
+ From<Number>
+ From<F>,
>(
a: T,
is_expected_non_zero: bool,
) -> (B, T) {
let p = F::modulus();
// p ≡ 1 (mod 8) has no exponentiation-only formula below; reject it.
if &p % 8 == 1 {
panic!("p % 8 == 1 not supported");
};
if &p % 4 == 3 {
// p ≡ 3 (mod 4): for a square `a`, a^((p+1)/4) is a square root.
// base_pow = a^((p-3)/4), so pow = a^((p+1)/4).
let base_pow = a.pow(&((&p - 3) / 4), is_expected_non_zero);
let pow = base_pow * a;
// is_square_field = a^((p-1)/2) (Euler's criterion):
// 1 for non-zero squares, -1 for non-squares, 0 for a == 0.
let is_square_field = pow * base_pow;
// Map x in {1, 0, -1} to {1, 1, 0} via 1 - x(x - 1)/2
// (F::TWO_INV presumably being the inverse of 2).
let is_square_bool =
T::from(1) - is_square_field * (is_square_field - T::from(1)) * T::from(F::TWO_INV);
// Turn the 0/1 field value into the boolean type `B`.
let is_square = is_square_bool.with_boolean_bounds().get_bit(0, false);
// Root for squares, zero otherwise.
(is_square, is_square.select(pow, T::from(0)))
} else {
// Remaining case for an odd prime modulus: p ≡ 5 (mod 8).
// Unless the caller promised a != 0, replace a == 0 by 1 so the exponentiation
// can use the non-zero hint; the mask is subtracted from the result at the end.
let is_zero_mask = if !is_expected_non_zero {
T::from(a.eq(T::from(0)))
} else {
T::from(0)
};
let a = a + is_zero_mask;
// base_pow = a^((p-5)/8), pow = a^((p+3)/8),
// is_fourth_power_field = a^((p-1)/4), is_square_field = a^((p-1)/2).
let base_pow = a.pow(&((&p - 5) / 8), true);
let pow = base_pow * a;
let is_fourth_power_field = pow * base_pow;
let is_square_field = is_fourth_power_field * is_fourth_power_field;
// Same x -> 1 - x(x - 1)/2 mapping as the other branch; after masking, a is expected
// to be non-zero, so x is 1 (square) or -1 (non-square).
let is_square_bool =
T::from(1) - is_square_field * (is_square_field - T::from(1)) * T::from(F::TWO_INV);
let is_square = is_square_bool.with_boolean_bounds().get_bit(0, false);
// For a square, a^((p-1)/4) is +1 or -1, and (x + 1)/2 maps it to 1 or 0.
// For a non-square that value is not boolean, but multiplying by
// is_square_bool (= 0) zeroes it before the bit is extracted.
let is_fourth_power_bool = (is_fourth_power_field + T::from(1)) * T::from(F::TWO_INV);
let is_fourth_power = (is_square_bool * is_fourth_power_bool)
.with_boolean_bounds()
.get_bit(0, false);
// If a^((p-1)/4) == 1, then pow^2 == a. If it is -1, then pow^2 == -a, so multiply
// by 2^((p-1)/4), which squares to -1 because 2 is a non-square when p ≡ 5 (mod 8).
// For a masked zero input, pow == 1 and the root is 1; subtracting the mask gives 0.
(
is_square,
is_square.select(
pow * is_fourth_power
.select(T::from(1), T::from(F::from(2).pow(&((&p - 1) / 4), true))),
T::from(0),
) - is_zero_mask,
)
}
}
impl<F: ActuallyUsedField> ArithmeticCircuit<F> for SqrtCircuit {
    /// Applies [`sqrt`] element-wise to concrete field elements.
    ///
    /// Each output is a square root of the matching input, or zero when that input is not
    /// a square. The `is_square` flag is discarded. This never returns an error.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        let roots: Vec<F> = x
            .into_iter()
            .map(|elem| {
                let (_, root) = sqrt::<F, bool, _>(elem, false);
                root
            })
            .collect();
        Ok(roots)
    }

    /// Reports `FieldBounds::All` for every output, one entry per input bound.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        bounds.iter().map(|_| FieldBounds::All).collect()
    }

    /// Applies [`sqrt`] element-wise to circuit values, in input order, and keeps only the
    /// root component of each result.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>> {
        let mut roots = Vec::with_capacity(vals.len());
        for value in vals {
            roots.push(sqrt::<F, BooleanValue, _>(value, false).1);
        }
        roots
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        utils::{field::ScalarField, used_field::UsedField},
    };
    use ff::Field;
    use rand::Rng;
    use std::marker::PhantomData;

    /// Checks `sqrt` on concrete `ScalarField` values.
    ///
    /// First it checks the trivial roots of 0 and 1. Then it compares 100 random elements
    /// against `ff::Field::sqrt`; both roots are normalized with `abs` because the sign of
    /// the root is not fixed.
    #[test]
    fn test_sqrt() {
        for trivial in [ScalarField::ZERO, ScalarField::ONE] {
            assert_eq!(
                sqrt::<ScalarField, bool, _>(trivial, false),
                (true, trivial)
            );
        }

        let rng = &mut crate::utils::test_rng::get();
        for _ in 0..100 {
            let x = ScalarField::random(&mut *rng);
            let (is_square, root) = sqrt::<ScalarField, bool, _>(x, false);
            let reference = x.sqrt();
            let expected = if bool::from(reference.is_some()) {
                (true, reference.unwrap().abs())
            } else {
                (false, ScalarField::ZERO)
            };
            assert_eq!((is_square, root.abs()), expected);
        }
    }

    impl<F: ActuallyUsedField> TestedArithmeticCircuit<F> for SqrtCircuit {
        fn gen_desc<R: Rng + ?Sized>(_rng: &mut R) -> Self {
            SqrtCircuit
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        /// Checks that each output squares back to its input when the input is a square
        /// (Euler's criterion `val^((p-1)/2) == 1`), and to zero otherwise.
        fn extra_checks(&self, inputs: Vec<F>, outputs: Vec<F>) {
            assert_eq!(inputs.len(), outputs.len());
            let p = F::modulus();
            let euler_exponent = (&p - 1) / 2;
            for (val, output) in inputs.into_iter().zip(outputs) {
                let expected = if val.pow(&euler_exponent, true) == F::ONE {
                    val
                } else {
                    F::ZERO
                };
                assert_eq!(output * output, expected);
            }
        }
    }

    #[test]
    fn tested() {
        SqrtCircuit::test_with_marker(1, 64, PhantomData::<ScalarField>)
    }
}