use crate::{
core::{
actually_used_field::ActuallyUsedField,
bounds::FieldBounds,
circuits::{
boolean::utils::{icpot_unsigned, shift_right_round_towards_zero, CircuitType},
traits::arithmetic_circuit::ArithmeticCircuit,
},
expressions::{expr::EvalFailure, field_expr::div_bounds},
global_value::value::FieldValue,
},
traits::{GreaterEqual, Select},
utils::{number::Number, used_field::UsedField},
};
use std::marker::PhantomData;
/// Descriptor of a division circuit over the field `F`.
///
/// `max_e_a` and `max_e_b` are the largest bit sizes accepted for the numerator and the
/// denominator respectively: `eval` reports undefined behaviour when `unsigned_bits()` of an
/// input exceeds the matching limit, and `new` refuses limits that fail `test_integrity`.
#[derive(Clone, Debug)]
pub struct FastDivide<F: UsedField> {
/// Maximum number of unsigned bits of the numerator.
max_e_a: usize,
/// Maximum number of unsigned bits of the denominator.
max_e_b: usize,
/// Ties the descriptor to the field `F` without storing a value of that type.
marker: PhantomData<F>,
}
impl<F: UsedField> FastDivide<F> {
    /// Returns `true` when the requested operand sizes fit in the field.
    ///
    /// The working precision is `max(max_e_a + max_e_b, 7)` bits (the same `eta` that
    /// `divide_unsigned_bitnums` uses); twice that precision plus two bits must not exceed
    /// `F::CAPACITY`.
    fn test_integrity(max_e_a: usize, max_e_b: usize) -> bool {
        let eta = std::cmp::max(max_e_a + max_e_b, 7);
        let required_bits = 2 * eta + 2;
        required_bits <= F::CAPACITY as usize
    }

    /// Builds a descriptor accepting numerators of at most `max_e_a` bits and denominators of
    /// at most `max_e_b` bits.
    ///
    /// # Panics
    ///
    /// Panics when the pair of limits does not pass `test_integrity`.
    #[allow(unused)]
    pub fn new(max_e_a: usize, max_e_b: usize) -> Self {
        assert!(
            Self::test_integrity(max_e_a, max_e_b),
            "max_e_a:{max_e_a} and/or max_e_b:{max_e_b} are too high"
        );
        Self {
            max_e_a,
            max_e_b,
            marker: PhantomData,
        }
    }
}
impl<F: ActuallyUsedField> FastDivide<F> {
pub fn inv_approx(&self, x: FieldValue<F>, eta: usize, precision_out: usize) -> FieldValue<F> {
let bounds = x.bounds();
let seventeen_inv = Number::power_of_two(eta) / 17;
let x_inv_init = (-32 * &seventeen_inv * x
+ 48 * &seventeen_inv * Number::power_of_two(eta))
>> (eta + 1);
let close_to_one_init = x * x_inv_init;
let close_to_one_init = close_to_one_init >> (eta - 1);
let (close_to_one_bounds, x_inv_bounds) = if bounds.signed_min().is_le_zero()
|| (bounds.signed_max() - F::power_of_two(eta)).is_ge_zero()
{
(FieldBounds::All, FieldBounds::All)
} else {
(
FieldBounds::new(
F::power_of_two(eta) - F::from(3 * &seventeen_inv),
F::power_of_two(eta) + F::from(3 * &seventeen_inv),
),
FieldBounds::new(F::power_of_two(eta - 2), F::power_of_two(eta) - F::ONE),
)
};
let close_to_one_init = close_to_one_init.with_bounds(close_to_one_bounds);
let n_iter = (((precision_out + 1) as f64 / 17f64.log2()).log2().ceil() as i32).max(0);
fn goldschmidt_iter<F: ActuallyUsedField>(
x_inv: FieldValue<F>,
close_to_one: FieldValue<F>,
eta: usize,
i: i32,
close_to_one_bounds: FieldBounds<F>,
x_inv_bounds: FieldBounds<F>,
) -> (FieldValue<F>, FieldValue<F>) {
if i == 0 {
(x_inv, close_to_one)
} else {
let update = -close_to_one + F::power_of_two(eta + 1);
let new_x_inv = (x_inv * update) >> eta;
let new_x_inv = new_x_inv.with_bounds(x_inv_bounds);
let new_close_to_one = (close_to_one * update) >> eta;
let new_close_to_one = new_close_to_one.with_bounds(close_to_one_bounds);
goldschmidt_iter(
new_x_inv,
new_close_to_one,
eta,
i - 1,
close_to_one_bounds,
x_inv_bounds,
)
}
}
let (x_inv, _) = goldschmidt_iter(
x_inv_init,
close_to_one_init,
eta,
n_iter,
close_to_one_bounds,
x_inv_bounds,
);
x_inv
}
/// Divides `a` by `b` using a fixed-point reciprocal of the normalised denominator followed
/// by a single upward correction step.
///
/// Intended for non-negative operands (`run` passes `a.abs()` and `b.abs()`). Output bounds
/// are only tracked when the operand bounds fit the configured sizes and `b` is known to be
/// strictly positive; otherwise all intermediate bounds are `FieldBounds::All`.
fn divide_unsigned_bitnums(&self, a: FieldValue<F>, b: FieldValue<F>) -> FieldValue<F> {
let a_bounds = a.bounds();
let b_bounds = b.bounds();
let b_max = b_bounds.unsigned_max();
// Working bit sizes of the operands, capped by the configured maxima.
let e_a = a_bounds.unsigned_bin_size().min(self.max_e_a);
let e_b = b_max.unsigned_bits().min(self.max_e_b);
// Fixed-point precision of the reciprocal; never below 7 bits (same floor as test_integrity).
let eta = (e_a + e_b).max(7);
// NOTE: inside this initializer `div_bounds(..)` is the imported function; the local
// `div_bounds` bound by this `let` only exists afterwards.
let (b_icpot_bounds, div_bounds, res_bounds, diff_bounds, icpot) =
if a_bounds.unsigned_bin_size() > self.max_e_a
|| b_bounds.unsigned_bin_size() > self.max_e_b
|| b_bounds.signed_min().is_le_zero()
{
// Operand bounds exceed the configured sizes, or b may be <= 0: claim no bounds at all.
let icpot = icpot_unsigned(
b,
b_bounds.unsigned_bin_size().min(e_b),
CircuitType::default(),
)
.0;
(
FieldBounds::All,
FieldBounds::All,
FieldBounds::All,
FieldBounds::All,
icpot,
)
} else {
// b * icpot is given the bounds [2^(e_b-1), 2^e_b - 1].
// NOTE(review): this suggests icpot is a power-of-two factor that moves the top bit of b
// to position e_b - 1 — confirm against icpot_unsigned.
let b_icpot_bounds =
FieldBounds::new(F::power_of_two(e_b - 1), F::power_of_two(e_b) - F::ONE);
let (div_min, div_max) = div_bounds(a_bounds, b_bounds).min_and_max(false);
// Bounds of the scaled product before the final shift, widened by one quotient unit each way.
let res_bounds = FieldBounds::new(
(div_min - F::ONE) * F::power_of_two(eta + e_b - 1),
(div_max + F::ONE) * F::power_of_two(eta + e_b - 1),
);
let icpot = icpot_unsigned(b, e_b, CircuitType::default()).0;
(
b_icpot_bounds,
FieldBounds::new(div_min, div_max),
res_bounds,
// Remainder before correction is bounded by [0, 2 * b_max].
FieldBounds::new(F::ZERO, F::from(2) * b_max),
icpot,
)
};
let b_icpot = b * icpot;
let b_icpot = b_icpot.with_bounds(b_icpot_bounds);
// Scale the normalised denominator up to eta bits, i.e. into [2^(eta-1), 2^eta) when the
// tracked bounds hold.
let b_norm = b_icpot * F::power_of_two(eta - e_b);
// Fixed-point reciprocal of the normalised denominator, requested precision e_a.
let b_inv = self.inv_approx(b_norm, eta, e_a);
// The numerator is scaled by the same factor as the denominator.
let a_icpot = a * icpot;
let res = (a_icpot * b_inv).with_bounds(res_bounds);
// Drop the fixed-point scaling to obtain the quotient estimate.
let res_before_correction = res >> (eta + e_b - 1);
// Correction step: if the remainder a - q * b is still >= b, the quotient is bumped by one.
let prod = res_before_correction * b;
let diff = (a - prod).with_bounds(diff_bounds);
let is_incorrect = diff.ge(b);
let corrected_res = FieldValue::<F>::from(is_incorrect) + res_before_correction;
corrected_res.with_bounds(div_bounds)
}
}
impl<F: UsedField> ArithmeticCircuit<F> for FastDivide<F> {
/// Evaluates the division directly on field elements: `x[0]` divided by `x[1]` with
/// `unsigned_euclidean_division`.
///
/// A zero denominator, or an operand whose `unsigned_bits()` exceeds its configured
/// maximum, is reported through `EvalFailure::err_ub`.
///
/// # Panics
///
/// Panics unless exactly two inputs are supplied.
fn eval(&self, x: Vec<F>) -> Result<Vec<F>, EvalFailure> {
    assert!(x.len() == 2, "FastDivide requires two numbers");
    let (numerator, denominator) = (x[0], x[1]);

    if denominator == F::ZERO {
        EvalFailure::err_ub("division by zero")?;
    }
    if numerator.unsigned_bits() > self.max_e_a {
        EvalFailure::err_ub("x[0] too big")?;
    }
    if denominator.unsigned_bits() > self.max_e_b {
        EvalFailure::err_ub("x[1] too big")?;
    }

    let quotient = numerator.unsigned_euclidean_division(denominator);
    Ok(vec![quotient])
}
/// Reports the single output as unbounded (`FieldBounds::All`); the input bounds are ignored.
fn bounds(&self, _bounds: Vec<FieldBounds<F>>) -> Vec<FieldBounds<F>> {
    Vec::from([FieldBounds::All])
}
/// Builds the division of `vals[0]` by `vals[1]` on `FieldValue`s.
///
/// Two strategies are used:
/// * when the bounds of the denominator pin it to a single value, the division becomes a
///   multiplication by a precomputed reciprocal followed by a shift;
/// * otherwise the absolute values are divided with `divide_unsigned_bitnums` and the sign is
///   reapplied afterwards.
///
/// Indexes `vals[0]` and `vals[1]` directly, so it panics when fewer than two values are given.
fn run(&self, vals: Vec<FieldValue<F>>) -> Vec<FieldValue<F>>
where
F: ActuallyUsedField,
{
let a = vals[0];
let b = vals[1];
// Operand bit sizes, capped by the configured maxima (only used by the constant branch).
let e_a = a.bounds().unsigned_bin_size().min(self.max_e_a);
let b_bounds = b.bounds();
let e_b = b_bounds.unsigned_bin_size().min(self.max_e_b);
let (b_min_bound, b_max_bound) = b_bounds.min_and_max(false);
// Evaluated before branching; only the constant-denominator branch reads it.
let is_neg_b = b.lt(0);
let res = if b_min_bound == b_max_bound {
// Denominator is a known constant.
let b_num = b_min_bound;
if b_num == F::ZERO {
// Division by the constant zero yields the constant 0.
return vec![0.into()];
}
let eta = e_a + e_b;
let neg_a = -a;
// offset is -a when b is negative and a otherwise.
// NOTE(review): presumably compensates the truncation in b_inv below — confirm.
let offset = is_neg_b.select(neg_a, a);
// Reciprocal of the constant, scaled by 2^eta.
let b_inv = F::power_of_two(eta).unsigned_euclidean_division(b_num);
let prod = a * b_inv;
let res = prod + offset;
// Remove the 2^eta scaling, rounding towards zero.
shift_right_round_towards_zero(res, eta)
} else {
// General case: divide magnitudes, then multiply by the product of the operand signs.
let abs_a = a.abs();
let abs_b = b.abs();
let abs_res = self.divide_unsigned_bitnums(abs_a, abs_b);
let sign_res = a.sign() * b.sign();
sign_res * abs_res
};
vec![res]
}
}
#[cfg(test)]
mod tests {
use crate::{
core::{
actually_used_field::ActuallyUsedField,
circuits::{
arithmetic::fast_divide::FastDivide,
traits::arithmetic_circuit::{tests::TestedArithmeticCircuit, ArithmeticCircuit},
},
expressions::{
expr::EvalValue,
field_expr::{FieldExpr, InputInfo},
},
ir::IntermediateRepresentation,
ir_builder::{ExprStore, IRBuilder},
},
utils::{field::ScalarField, number::Number, used_field::UsedField},
};
use ff::{Field, PrimeField};
use rand::Rng;
use rustc_hash::FxHashMap;
use std::rc::Rc;
/// Compares both IRs on numerators located right at quotient boundaries of the divisor `div`.
///
/// A quotient `q` is drawn between the quotients obtained from the numerator's minimum and
/// maximum. For `q + 1` and `q`, the numerators `q * div - 1`, `q * div` and `q * div + 1`
/// that lie within the numerator's `[min, max]` are passed to `test_eq_with_vals`.
///
/// Nothing is checked when `div` is zero or when both numerator bounds give the same quotient.
fn test_interesting_vals_for_division<R: Rng + ?Sized>(
    rng: &mut R,
    div: ScalarField,
    numerator_input_info: Rc<InputInfo<ScalarField>>,
    ctrl_ir: &IntermediateRepresentation,
    test_ir: &IntermediateRepresentation,
) {
    if div == ScalarField::ZERO {
        return;
    }

    let lowest_num = numerator_input_info.min;
    let highest_num = numerator_input_info.max;
    let quotient_at_min = lowest_num.unsigned_euclidean_division(div);
    let quotient_at_max = highest_num.unsigned_euclidean_division(div);
    if quotient_at_min == quotient_at_max {
        return;
    }

    // Order the two quotients before sampling between them.
    let (q_low, q_high) = if quotient_at_min > quotient_at_max {
        (quotient_at_max, quotient_at_min)
    } else {
        (quotient_at_min, quotient_at_max)
    };
    let sampled_q = ScalarField::gen_inclusive_range(rng, q_low, q_high);

    let one = ScalarField::ONE;
    for q in [sampled_q + one, sampled_q] {
        let multiple = q * div;
        let candidates = [multiple - one, multiple, multiple + one];
        for num in candidates {
            let out_of_range = lowest_num > num || num > highest_num;
            if out_of_range {
                continue;
            }
            let mut input_vals = FxHashMap::<usize, _>::default();
            input_vals.insert(0, EvalValue::Scalar(num));
            input_vals.insert(1, EvalValue::Scalar(div));
            IntermediateRepresentation::test_eq_with_vals(
                rng,
                ctrl_ir,
                test_ir,
                &mut input_vals,
            );
        }
    }
}
/// Operand size used for the largest test cases.
///
/// With `max_e_a == max_e_b == m`, `FastDivide::test_integrity` requires
/// `2 * max(2 * m, 7) + 2 <= CAPACITY`; for `2 * m >= 7` that is `m <= (CAPACITY - 2) / 4`.
const MAX_DIVISION_SIZE: usize = (ScalarField::CAPACITY as usize - 2) / 4;
/// Checks `FastDivide` against the reference `FieldExpr::Div` when both operands are inputs,
/// over every pairing of the tested operand magnitudes.
#[test]
fn divide_test() {
    let rng = &mut crate::utils::test_rng::get();
    let magnitudes = [1, 4, 16, MAX_DIVISION_SIZE];
    for magnitude in magnitudes {
        for magnitude_1 in magnitudes {
            // Fewer repetitions once the combined magnitude reaches 100 bits.
            let repetitions = if magnitude + magnitude_1 >= 100 { 1 } else { 4 };
            for _ in 0..repetitions {
                // Numerator in [0, 2^magnitude], denominator in [1, 2^magnitude_1].
                let input_info_0 = InputInfo::generate(
                    rng,
                    &Number::from(0),
                    &Number::power_of_two(magnitude),
                );
                let input_info_1 = InputInfo::generate(
                    rng,
                    &Number::from(1),
                    &Number::power_of_two(magnitude_1),
                );

                // Both IRs share the same two inputs.
                let mut reference_builder = IRBuilder::new(true);
                let numerator =
                    reference_builder.push_field(FieldExpr::Input(0, Rc::clone(&input_info_0)));
                let denominator =
                    reference_builder.push_field(FieldExpr::Input(1, Rc::clone(&input_info_1)));
                let mut candidate_builder = reference_builder.clone();

                // Candidate: the circuit under test.
                let circuit = FastDivide::<ScalarField>::new(magnitude, magnitude_1);
                let candidate_output =
                    circuit.run_usize(&[numerator, denominator], &mut candidate_builder);
                let test_ir = candidate_builder.into_ir(candidate_output);

                // Reference: the built-in division expression.
                let reference_output = reference_builder
                    .push_field(FieldExpr::<ScalarField, _>::Div(numerator, denominator));
                let ctrl_ir = reference_builder.into_ir(vec![reference_output]);

                IntermediateRepresentation::test_eq(rng, &ctrl_ir, &test_ir, 1);

                let div =
                    ScalarField::gen_inclusive_range(rng, input_info_1.min, input_info_1.max);
                test_interesting_vals_for_division(rng, div, input_info_0, &ctrl_ir, &test_ir);
            }
        }
    }
}
/// Checks `FastDivide` against the reference `FieldExpr::Div` when the denominator is a
/// constant (`FieldExpr::Val`) and the numerator is an input.
#[test]
fn divide_const_test() {
    let rng = &mut crate::utils::test_rng::get();
    for _ in 0..64 {
        let zero = Number::from(0);
        let size_limit = Number::power_of_two(MAX_DIVISION_SIZE);
        let input_info_0 = InputInfo::generate(rng, &zero, &size_limit);
        // Only used as the range the constant denominator is drawn from (lower bound 1).
        let input_info_1 = InputInfo::generate(rng, &(zero + 1), &size_limit);

        let mut reference_builder = IRBuilder::new(true);
        let numerator =
            reference_builder.push_field(FieldExpr::Input(0, Rc::clone(&input_info_0)));
        let my_val = ScalarField::gen_inclusive_range(rng, input_info_1.min, input_info_1.max);
        let denominator = reference_builder.push_field(FieldExpr::Val(my_val));
        let mut candidate_builder = reference_builder.clone();

        // Candidate: the circuit under test, sized for the largest operands.
        let circuit = FastDivide::<ScalarField>::new(MAX_DIVISION_SIZE, MAX_DIVISION_SIZE);
        let candidate_output =
            circuit.run_usize(&[numerator, denominator], &mut candidate_builder);
        let test_ir = candidate_builder.into_ir(candidate_output);

        // Reference: the built-in division expression.
        let reference_output = reference_builder
            .push_field(FieldExpr::<ScalarField, _>::Div(numerator, denominator));
        let ctrl_ir = reference_builder.into_ir(vec![reference_output]);

        IntermediateRepresentation::test_eq(rng, &ctrl_ir, &test_ir, 4);
        test_interesting_vals_for_division(rng, my_val, input_info_0, &ctrl_ir, &test_ir);
    }
}
impl<F: ActuallyUsedField> TestedArithmeticCircuit<F> for FastDivide<F> {
    /// Draws operand size limits until a pair passes `test_integrity`, then builds the
    /// descriptor from it.
    ///
    /// The starting pair (`1 << 16` each) is tested first; random sizes below
    /// `ScalarField::NUM_BITS` are drawn only while the current pair is rejected.
    fn gen_desc<R: Rng + ?Sized>(rng: &mut R) -> Self {
        let (mut max_e_a, mut max_e_b) = (1 << 16, 1 << 16);
        loop {
            if Self::test_integrity(max_e_a, max_e_b) {
                break FastDivide::new(max_e_a, max_e_b);
            }
            max_e_a = (rng.next_u32() % ScalarField::NUM_BITS) as usize;
            max_e_b = (rng.next_u32() % ScalarField::NUM_BITS) as usize;
        }
    }

    /// The circuit always takes two inputs: numerator and denominator.
    fn gen_n_inputs<R: Rng + ?Sized>(&self, _rng: &mut R) -> usize {
        2
    }
}
/// Runs the shared `test` entry point (presumably provided by `TestedArithmeticCircuit`) on
/// `FastDivide` over `ScalarField` with the arguments `16` and `8`.
#[test]
fn tested() {
    type Circuit = FastDivide<ScalarField>;
    Circuit::test(16, 8);
}
}