use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::{
boolean::{
boolean_value::BooleanValue,
utils::{icpot_signed, shift_right, CircuitType},
},
traits::arithmetic_circuit::ArithmeticCircuit,
},
expressions::expr::EvalFailure,
global_value::value::FieldValue,
},
traits::Select,
types::DOUBLE_PRECISION_MANTISSA,
utils::{number::Number, used_field::UsedField},
};
/// Number of extra low-order bits carried by the intermediate values of
/// `DivSqrt::inv_sqrt_approx` on top of `precision_out`. The final right shift of that
/// function (`precision_out + 2 * PRECISION_MARGIN`) accounts for them.
const PRECISION_MARGIN: usize = 4;
/// Integer coefficients of the 20-term series summed by `DivSqrt::inv_sqrt_approx`;
/// entry `k` multiplies the `k`-th Chebyshev polynomial term built there.
// NOTE(review): the values look like fixed-point coefficients of 1/sqrt on the normalized
// input range; their scale and origin are not visible in this file — confirm against the
// script that generated them.
const CHEBYSHEV_COEFFS: [i64; 20] = [
60141202130508328,
-10357137487257478,
1334415195357729,
-190910445362612,
28671436609356,
-4428426634545,
696606235352,
-110996612595,
17855626742,
-2893586923,
471670337,
-77251781,
12702626,
-2095693,
346782,
-57414,
9559,
-1595,
296,
-53,
];
/// Circuit description for the fixed-point operation `a / sqrt(b)`.
#[derive(Clone, Debug)]
pub struct DivSqrt {
// Number of fractional bits of input `a` (`div_sqrt_public` scales `a` by 2^-precision_a).
precision_a: usize,
// Number of fractional bits of input `b` (`div_sqrt_public` scales `b` by 2^-precision_b).
precision_b: usize,
// Working precision in bits; `new` always sets it to `DOUBLE_PRECISION_MANTISSA`.
precision_out: usize,
}
impl DivSqrt {
pub const fn new(precision_a: usize, precision_b: usize) -> Self {
if precision_a > DOUBLE_PRECISION_MANTISSA || precision_b > DOUBLE_PRECISION_MANTISSA {
panic!("input precision must be at most 52",);
}
DivSqrt {
precision_a,
precision_b,
precision_out: DOUBLE_PRECISION_MANTISSA,
}
}
/// Builds the square-root scaling factor from the bit list returned by `icpot_signed`.
///
/// Returns `(sqrt_icpot, icpot, is_non_pos)`. `icpot` and `is_non_pos` are passed through
/// unchanged from `icpot_signed(b, CircuitType::default())`. `sqrt_icpot` is a sum with one
/// term per bit (visited in reverse order): a set bit contributes either a plain power of
/// two or a power of two multiplied by the field encoding of `sqrt(2)`, depending on the
/// parity of its distance from `precision_b`.
///
/// The result is tagged with bounds whose lower end is `0` when `b` may be non-positive
/// and `2^precision_out` otherwise, and whose upper end is
/// `2^(ceil(len / 2) + 1 + precision_out)`.
// NOTE(review): presumably `icpot_bits` is one-hot and `icpot` is a power-of-two factor
// normalizing `b`; that contract lives in `icpot_signed` and should be confirmed there.
fn init_inv_sqrt<F: ActuallyUsedField>(
    &self,
    b: FieldValue<F>,
) -> (FieldValue<F>, FieldValue<F>, BooleanValue) {
    let (icpot, icpot_bits, is_non_pos) = icpot_signed(b, CircuitType::default());
    let sqrt_two = F::from(2f64.sqrt());
    // Reading the length only needs a borrow, so the bit list is not cloned.
    let len = icpot_bits.len();
    let offset_sqrt_icpot = (len as i32 - self.precision_b as i32 + 1) / 2;
    let precision_sqrt_icpot = offset_sqrt_icpot + self.precision_out as i32;
    let sqrt_icpot = icpot_bits
        .into_iter()
        .rev()
        .enumerate()
        .fold(FieldValue::<F>::from(0), |acc, (i, bit)| {
            acc + bit.select(
                FieldValue::from(if (i as i32 - self.precision_b as i32) % 2 == 0 {
                    // Even distance from `precision_b`: a plain power of two.
                    F::power_of_two(
                        (((self.precision_b as i32 - i as i32) / 2) + precision_sqrt_icpot)
                            as usize,
                    )
                } else {
                    // Odd distance: a power of two times the `sqrt(2)` constant.
                    F::power_of_two(
                        ((self.precision_b as i32 - i as i32 - 1) / 2 + offset_sqrt_icpot)
                            as usize,
                    ) * sqrt_two
                }),
                FieldValue::<F>::from(0),
            )
        })
        .with_bounds(FieldBounds::new(
            if b.bounds().signed_min().is_le_zero() {
                F::ZERO
            } else {
                F::power_of_two(self.precision_out)
            },
            F::power_of_two(len.div_ceil(2) + 1 + self.precision_out),
        ));
    (sqrt_icpot, icpot, is_non_pos)
}
/// Evaluates the Chebyshev series defined by `CHEBYSHEV_COEFFS` at `b_normalized`.
///
/// With `one = 2^(precision_out + PRECISION_MARGIN)`, the input is first mapped to
/// `z = 2^(1 + PRECISION_MARGIN) * b_normalized - 3 * one`; for `b_normalized` in
/// `[2^precision_out, 2^(precision_out + 1))` this puts `z` in `[-one, one)`.
/// The terms `T_0 = one`, `T_1 = z` and
/// `T_k = 2 * shift_right(z * T_{k-1}, precision_out + PRECISION_MARGIN) - T_{k-2}`
/// are then weighted by the coefficient table, summed, and the sum is shifted right by
/// `precision_out + 2 * PRECISION_MARGIN` bits.
fn inv_sqrt_approx<F: ActuallyUsedField>(&self, b_normalized: FieldValue<F>) -> FieldValue<F> {
    let scale_bits = self.precision_out + PRECISION_MARGIN;
    let one = FieldValue::<F>::from(Number::power_of_two(scale_bits));
    let z = FieldValue::from(F::power_of_two(1 + PRECISION_MARGIN)) * b_normalized - 3 * one;
    // One polynomial term per coefficient; the count follows the table instead of a
    // separately maintained literal.
    let term_count = CHEBYSHEV_COEFFS.len();
    let mut chebyshev_polynomials = Vec::with_capacity(term_count);
    chebyshev_polynomials.push(one);
    chebyshev_polynomials.push(z);
    for i in 2..term_count {
        let last = chebyshev_polynomials[i - 1];
        let second_last = chebyshev_polynomials[i - 2];
        // Three-term recurrence; the product is rescaled back to `scale_bits`.
        chebyshev_polynomials.push(2 * shift_right(z * last, scale_bits, true) - second_last);
    }
    CHEBYSHEV_COEFFS
        .into_iter()
        .zip(chebyshev_polynomials)
        .fold(FieldValue::<F>::from(0), |acc, (c, p)| {
            acc + FieldValue::from(F::from(Number::from(c))) * p
        })
        >> (self.precision_out + 2 * PRECISION_MARGIN)
}
/// Builds the circuit value for `a / sqrt(b)`.
///
/// When the bounds of `b` show that it can never be positive, the constant `0` is returned.
/// Otherwise `b` is rescaled using the factor from `init_inv_sqrt`, passed through
/// `inv_sqrt_approx`, and multiplied by a correspondingly rescaled `a`. The result is shifted
/// right by `precision_out` bits and tagged with the bounds from `div_sqrt_bounds`.
pub fn div_sqrt<F: ActuallyUsedField>(
&self,
a: FieldValue<F>,
b: FieldValue<F>,
) -> FieldValue<F> {
let a_bounds = a.bounds();
let b_bounds = b.bounds();
// `b` is provably non-positive: no circuit is built, the result is the constant 0.
if b_bounds.signed_max().is_le_zero() {
FieldValue::<F>::from(0)
} else {
// The third tuple element (the non-positive flag) is not used here.
let (sqrt_icpot, icpot, _) = self.init_inv_sqrt(b);
let b_icpot = b * icpot;
// Shift amount derived from the bit size of `b`'s bounds and `precision_out`:
// positive means shift right, otherwise multiply by the matching power of two.
let offset_icpot = b_bounds.bin_size(true) as i32 - self.precision_out as i32 - 2;
let b_normalized = if offset_icpot > 0 {
b_icpot >> (offset_icpot as usize)
} else {
FieldValue::from(F::power_of_two((-offset_icpot) as usize)) * b_icpot
};
// Bounds asserted on the normalized value: at most 2^(precision_out + 1) - 1, and at
// least 2^precision_out unless `b` may be non-positive (then the lower end is 0).
let b_normalized_bounds = FieldBounds::new(
if b_bounds.signed_min().is_le_zero() {
F::ZERO
} else {
F::power_of_two(self.precision_out)
},
F::power_of_two(self.precision_out + 1) - F::ONE,
);
let inv_sqrt_b_normalized =
self.inv_sqrt_approx(b_normalized.with_bounds(b_normalized_bounds));
let offset_sqrt_icpot = (b_bounds.bin_size(true) as i32 - self.precision_b as i32) / 2;
// Three ways of forming `a * sqrt_icpot` scaled down by `offset_sqrt_icpot` and
// `precision_a` bits; they differ in where the `offset_sqrt_icpot` shift is applied.
// NOTE(review): presumably the choice avoids overflowing the field in the product —
// confirm against the semantics of `bin_size` and `shift_right`.
let a_normalized = if a_bounds.bin_size(true) + sqrt_icpot.bounds().bin_size(false)
< F::NUM_BITS as usize
{
// Combined bit sizes are below `F::NUM_BITS`: multiply first, then shift once.
shift_right(
a * sqrt_icpot,
(offset_sqrt_icpot + self.precision_a as i32) as usize,
true,
)
} else if b_bounds == FieldBounds::All && a.get_id() != b.get_id() {
// `b` is unbounded and `a` is a different value than `b`: pre-shift `sqrt_icpot`.
shift_right(
a * (sqrt_icpot >> offset_sqrt_icpot as usize),
self.precision_a,
true,
)
} else {
// Otherwise pre-shift `a`; a negative offset is clamped to 0.
shift_right(
shift_right(a, offset_sqrt_icpot.max(0) as usize, true) * sqrt_icpot,
self.precision_a,
true,
)
};
// Final product, rescaled by `precision_out` bits and tagged with the output bounds.
shift_right(
a_normalized * inv_sqrt_b_normalized,
self.precision_out,
true,
)
.with_bounds(self.div_sqrt_bounds(a_bounds, b_bounds))
}
}
/// Reference computation of `a / sqrt(b)` on plain field elements via `f64`.
///
/// `a` and `b` are read as signed numbers with `precision_a` and `precision_b` fractional
/// bits respectively. A non-positive `b` yields `F::ZERO`.
fn div_sqrt_public<F: UsedField>(&self, a: F, b: F) -> F {
    let radicand_raw = b.to_signed_number();
    if radicand_raw > 0 {
        let scale_a = 2f64.powi(-(self.precision_a as i32));
        let scale_b = 2f64.powi(-(self.precision_b as i32));
        let numerator = f64::from(a.to_signed_number()) * scale_a;
        let radicand = f64::from(radicand_raw) * scale_b;
        F::from(numerator / radicand.sqrt())
    } else {
        F::ZERO
    }
}
/// Computes output bounds for `a / sqrt(b)` from the bounds of both inputs.
///
/// Returns `FieldBounds::All` when the inputs are too wide for the field. Otherwise the
/// reference function is evaluated at the corners of the input box, each corner widened by
/// its `eval_gap`; the divisor's lower corner is clamped to `F::ONE` because
/// `div_sqrt_public` returns zero for non-positive divisors. When `b` may be non-positive
/// the interval is extended to contain zero.
fn div_sqrt_bounds<F: UsedField>(
    &self,
    a_bounds: FieldBounds<F>,
    b_bounds: FieldBounds<F>,
) -> FieldBounds<F> {
    let field_bits = F::NUM_BITS as usize;
    if a_bounds.bin_size(true) + 2 * self.precision_out > field_bits
        || b_bounds.bin_size(true) + self.precision_out > field_bits
    {
        return FieldBounds::All;
    }

    let (a_min, a_max) = a_bounds.min_and_max(true);
    let (b_min, b_max) = b_bounds.min_and_max(true);
    // Smallest divisor at which the reference function is actually evaluated.
    let b_floor = b_min.max(F::ONE, true);

    let below = |a: F, b: F| self.div_sqrt_public(a, b) - self.eval_gap(&[a, b]);
    let above = |a: F, b: F| self.div_sqrt_public(a, b) + self.eval_gap(&[a, b]);

    let lower = below(a_min, b_max).min(below(a_min, b_floor), true);
    let upper = above(a_max, b_floor).max(above(a_max, b_max), true);

    let (lower, upper) = if b_min.is_le_zero() {
        (lower.min(F::ZERO, true), upper.max(F::ZERO, true))
    } else {
        (lower, upper)
    };
    FieldBounds::new(lower, upper)
}
}
/// Exposes `DivSqrt` through the generic circuit interface; inputs are `[a, b]`.
impl<F: UsedField> ArithmeticCircuit<F> for DivSqrt {
    /// Evaluates `a / sqrt(b)` on plain field elements.
    ///
    /// Returns an undefined-behavior failure ("input out of range") when the signed bit
    /// width of `a` plus `2 * precision_out`, or that of `b` plus `3 * precision_out / 2`,
    /// exceeds `F::NUM_BITS`.
    ///
    /// # Panics
    ///
    /// Panics when `x` does not contain exactly two values.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        if x.len() != 2 {
            panic!("div_sqrt requires two inputs")
        }
        let a = x[0];
        let b = x[1];
        if a.signed_bits() + 2 * self.precision_out > F::NUM_BITS as usize
            || b.signed_bits() + (3 * self.precision_out) / 2 > F::NUM_BITS as usize
        {
            return EvalFailure::err_ub("input out of range");
        }
        Ok(vec![self.div_sqrt_public(a, b)])
    }

    /// Tolerance around the reference result: its absolute value shifted right by 47 bits,
    /// but never less than `F::from(4)`.
    fn eval_gap(&self, x: &[F]) -> F {
        (self.div_sqrt_public(x[0], x[1]).abs() >> 47).max(F::from(4), true)
    }

    /// Output bounds for the input bounds `[a_bounds, b_bounds]`.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        vec![self.div_sqrt_bounds(bounds[0], bounds[1])]
    }

    /// Builds the circuit for the two values `[a, b]`.
    ///
    /// # Panics
    ///
    /// Panics when `vals` does not contain exactly two values.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        if vals.len() != 2 {
            // Message aligned with `eval` (previously misspelled as "two input2").
            panic!("div_sqrt requires two inputs")
        }
        vec![self.div_sqrt(vals[0], vals[1])]
    }
}
/// Circuit description for the fixed-point square root; `sqrt` delegates to `DivSqrt`
/// with the input used as both numerator and radicand.
#[derive(Clone, Debug)]
pub struct Sqrt {
// Number of fractional bits of the input (`sqrt_public` scales it by 2^-precision_in).
precision_in: usize,
}
impl Sqrt {
    /// Creates a `Sqrt` description for inputs with `precision_in` fractional bits.
    ///
    /// # Panics
    ///
    /// Panics when `precision_in` is larger than `DOUBLE_PRECISION_MANTISSA`.
    pub const fn new(precision_in: usize) -> Self {
        assert!(
            precision_in <= DOUBLE_PRECISION_MANTISSA,
            "precision_in must be at most 52"
        );
        Self { precision_in }
    }

    /// Builds the circuit value for `sqrt(x)` as `x / sqrt(x)` through `DivSqrt`, then
    /// replaces the bounds with the ones from `sqrt_bounds`.
    pub fn sqrt<F: ActuallyUsedField>(&self, x: FieldValue<F>) -> FieldValue<F> {
        let input_bounds = x.bounds();
        let divider = DivSqrt::new(self.precision_in, self.precision_in);
        let root = divider.div_sqrt(x, x);
        root.with_bounds(self.sqrt_bounds(input_bounds))
    }

    /// Reference computation of `sqrt(x)` on a plain field element via `f64`.
    ///
    /// A non-positive `x` yields `F::ZERO`.
    fn sqrt_public<F: UsedField>(&self, x: F) -> F {
        let raw = x.to_signed_number();
        if raw > 0 {
            let scale = 2f64.powi(-(self.precision_in as i32));
            let value = f64::from(raw) * scale;
            F::from(value.sqrt())
        } else {
            F::ZERO
        }
    }

    /// Output bounds: the reference result at both ends of the input range, each widened
    /// by its `eval_gap`.
    fn sqrt_bounds<F: UsedField>(&self, bounds: FieldBounds<F>) -> FieldBounds<F> {
        let (lowest, highest) = bounds.min_and_max(true);
        let lower = self.sqrt_public(lowest) - self.eval_gap(&[lowest]);
        let upper = self.sqrt_public(highest) + self.eval_gap(&[highest]);
        FieldBounds::new(lower, upper)
    }
}
/// Exposes `Sqrt` through the generic circuit interface; the single input is `x`.
impl<F: UsedField> ArithmeticCircuit<F> for Sqrt {
    /// Evaluates `sqrt(x)` on a plain field element.
    ///
    /// # Panics
    ///
    /// Panics when `x` does not contain exactly one value.
    fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
        match x.as_slice() {
            [value] => Ok(vec![self.sqrt_public(*value)]),
            _ => panic!("sqrt requires one input"),
        }
    }

    /// Tolerance around the reference result: the result shifted right by 47 bits, but
    /// never less than `F::from(4)`.
    fn eval_gap(&self, x: &[F]) -> F {
        let root = self.sqrt_public(x[0]);
        (root >> 47).max(F::from(4), true)
    }

    /// Output bounds for the single input bound.
    fn bounds(&self, bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
        let input_bounds = bounds[0];
        vec![self.sqrt_bounds(input_bounds)]
    }

    /// Builds the circuit for the single value in `vals`.
    ///
    /// # Panics
    ///
    /// Panics when `vals` does not contain exactly one value.
    fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
    where
        F: ActuallyUsedField,
    {
        match vals.as_slice() {
            [value] => vec![self.sqrt(*value)],
            _ => panic!("sqrt requires one input"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        core::circuits::traits::arithmetic_circuit::tests::TestedArithmeticCircuit,
        utils::field::ScalarField,
    };
    use rand::Rng;

    /// Draws a precision: starts at 52 and decrements once for every successful coin flip.
    fn random_precision<R: Rng + ?Sized>(rng: &mut R) -> usize {
        let mut precision = 52_i32;
        while rng.gen_bool(0.5) {
            precision -= 1;
        }
        precision as usize
    }

    impl TestedArithmeticCircuit<ScalarField> for DivSqrt {
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            // Drawn in this order so the RNG stream is consumed for `a` first, then `b`.
            let precision_a = random_precision(rng);
            let precision_b = random_precision(rng);
            Self::new(precision_a, precision_b)
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            2
        }

        fn input_precisions(&self) -> Vec<usize> {
            vec![self.precision_a, self.precision_b]
        }
    }

    impl TestedArithmeticCircuit<ScalarField> for Sqrt {
        fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
            Self::new(random_precision(rng))
        }

        fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
            1
        }

        fn input_precisions(&self) -> Vec<usize> {
            vec![self.precision_in]
        }
    }

    #[test]
    fn tested_div_sqrt() {
        DivSqrt::test(16, 4)
    }

    #[test]
    fn tested_sqrt() {
        Sqrt::test(16, 4)
    }
}