use super::{overhead, params::get_params, AllocatedNonNativeFieldVar};
use crate::{
alloc::AllocVar,
boolean::Boolean,
eq::EqGadget,
fields::{fp::FpVar, FieldVar},
R1CSVar,
};
use ark_ff::{biginteger::BigInteger, BitIteratorBE, One, PrimeField, Zero};
use ark_relations::{
ns,
r1cs::{ConstraintSystemRef, Result as R1CSResult},
};
use ark_std::{cmp::min, marker::PhantomData, vec, vec::Vec};
use num_bigint::BigUint;
use num_integer::Integer;
/// Interprets `limbs` as a multi-limb integer and returns its value.
///
/// The last limb is the least significant; every earlier limb carries an extra
/// factor of `2^bits_per_limb`. Limbs are added as full integers, so a limb that
/// is wider than `bits_per_limb` bits simply overlaps the next position.
/// An empty slice yields zero.
pub fn limbs_to_bigint<BaseField: PrimeField>(
    bits_per_limb: usize,
    limbs: &[BaseField],
) -> BigUint {
    // Horner's rule, walking from the most significant limb down.
    limbs.iter().fold(BigUint::zero(), |acc, limb| {
        let limb_value = BigUint::from_bytes_le(&limb.into_bigint().to_bytes_le());
        (acc << bits_per_limb) + limb_value
    })
}
/// Maps an arbitrary non-negative integer into `BaseField`, returning
/// `bigint mod p` where `p` is the base-field modulus.
///
/// Inputs at or above the modulus are reduced rather than rejected. The
/// reduction is performed by `PrimeField::from_le_bytes_mod_order`, so no
/// intermediate field constant has to be constructed (and unwrapped) here.
pub fn bigint_to_basefield<BaseField: PrimeField>(bigint: &BigUint) -> BaseField {
    BaseField::from_le_bytes_mod_order(&bigint.to_bytes_le())
}
/// Namespace for the limb-reduction gadgets used by
/// `AllocatedNonNativeFieldVar<TargetField, BaseField>`.
///
/// The struct holds only type markers; all functionality lives in its
/// associated functions.
pub struct Reducer<TargetField: PrimeField, BaseField: PrimeField> {
/// Zero-sized marker for the `TargetField` type parameter.
pub target_phantom: PhantomData<TargetField>,
/// Zero-sized marker for the `BaseField` type parameter.
pub base_phantom: PhantomData<BaseField>,
}
impl<TargetField: PrimeField, BaseField: PrimeField> Reducer<TargetField, BaseField> {
#[tracing::instrument(target = "r1cs")]
pub fn limb_to_bits(
limb: &FpVar<BaseField>,
num_bits: usize,
) -> R1CSResult<Vec<Boolean<BaseField>>> {
let cs = limb.cs();
let num_bits = min(BaseField::MODULUS_BIT_SIZE as usize - 1, num_bits);
let mut bits_considered = Vec::with_capacity(num_bits);
let limb_value = limb.value().unwrap_or_default();
let num_bits_to_shave =
BaseField::BigInt::NUM_LIMBS * 64 - (BaseField::MODULUS_BIT_SIZE as usize);
for b in BitIteratorBE::new(limb_value.into_bigint())
.skip(num_bits_to_shave + (BaseField::MODULUS_BIT_SIZE as usize - num_bits))
{
bits_considered.push(b);
}
if cs == ConstraintSystemRef::None {
let mut bits = vec![];
for b in bits_considered {
bits.push(Boolean::<BaseField>::Constant(b));
}
Ok(bits)
} else {
let mut bits = vec![];
for b in bits_considered {
bits.push(Boolean::<BaseField>::new_witness(
ark_relations::ns!(cs, "bit"),
|| Ok(b),
)?);
}
let mut bit_sum = FpVar::<BaseField>::zero();
let mut coeff = BaseField::one();
for bit in bits.iter().rev() {
bit_sum +=
<FpVar<BaseField> as From<Boolean<BaseField>>>::from((*bit).clone()) * coeff;
coeff.double_in_place();
}
bit_sum.enforce_equal(limb)?;
Ok(bits)
}
}
#[tracing::instrument(target = "r1cs")]
pub fn reduce(elem: &mut AllocatedNonNativeFieldVar<TargetField, BaseField>) -> R1CSResult<()> {
let new_elem =
AllocatedNonNativeFieldVar::new_witness(ns!(elem.cs(), "normal_form"), || {
Ok(elem.value().unwrap_or_default())
})?;
elem.conditional_enforce_equal(&new_elem, &Boolean::TRUE)?;
*elem = new_elem;
Ok(())
}
#[tracing::instrument(target = "r1cs")]
pub fn post_add_reduce(
elem: &mut AllocatedNonNativeFieldVar<TargetField, BaseField>,
) -> R1CSResult<()> {
let params = get_params(
TargetField::MODULUS_BIT_SIZE as usize,
BaseField::MODULUS_BIT_SIZE as usize,
elem.get_optimization_type(),
);
let surfeit = overhead!(elem.num_of_additions_over_normal_form + BaseField::one()) + 1;
if BaseField::MODULUS_BIT_SIZE as usize > 2 * params.bits_per_limb + surfeit + 1 {
Ok(())
} else {
Self::reduce(elem)
}
}
#[tracing::instrument(target = "r1cs")]
pub fn pre_mul_reduce(
elem: &mut AllocatedNonNativeFieldVar<TargetField, BaseField>,
elem_other: &mut AllocatedNonNativeFieldVar<TargetField, BaseField>,
) -> R1CSResult<()> {
assert_eq!(
elem.get_optimization_type(),
elem_other.get_optimization_type()
);
let params = get_params(
TargetField::MODULUS_BIT_SIZE as usize,
BaseField::MODULUS_BIT_SIZE as usize,
elem.get_optimization_type(),
);
if 2 * params.bits_per_limb + ark_std::log2(params.num_limbs) as usize
> BaseField::MODULUS_BIT_SIZE as usize - 1
{
panic!("The current limb parameters do not support multiplication.");
}
loop {
let prod_of_num_of_additions = (elem.num_of_additions_over_normal_form
+ BaseField::one())
* (elem_other.num_of_additions_over_normal_form + BaseField::one());
let overhead_limb = overhead!(prod_of_num_of_additions.mul(
&BaseField::from_bigint(<BaseField as PrimeField>::BigInt::from(
(params.num_limbs) as u64
))
.unwrap()
));
let bits_per_mulresult_limb = 2 * (params.bits_per_limb + 1) + overhead_limb;
if bits_per_mulresult_limb < BaseField::MODULUS_BIT_SIZE as usize {
break;
}
if elem.num_of_additions_over_normal_form
>= elem_other.num_of_additions_over_normal_form
{
Self::reduce(elem)?;
} else {
Self::reduce(elem_other)?;
}
}
Ok(())
}
#[tracing::instrument(target = "r1cs")]
pub fn pre_eq_reduce(
elem: &mut AllocatedNonNativeFieldVar<TargetField, BaseField>,
) -> R1CSResult<()> {
if elem.is_in_the_normal_form {
return Ok(());
}
Self::reduce(elem)
}
#[tracing::instrument(target = "r1cs")]
pub fn group_and_check_equality(
surfeit: usize,
bits_per_limb: usize,
shift_per_limb: usize,
left: &[FpVar<BaseField>],
right: &[FpVar<BaseField>],
) -> R1CSResult<()> {
let cs = left.cs().or(right.cs());
let zero = FpVar::<BaseField>::zero();
let mut limb_pairs = Vec::<(FpVar<BaseField>, FpVar<BaseField>)>::new();
let num_limb_in_a_group = (BaseField::MODULUS_BIT_SIZE as usize
- 1
- surfeit
- 1
- 1
- 1
- (bits_per_limb - shift_per_limb))
/ shift_per_limb;
let shift_array = {
let mut array = Vec::new();
let mut cur = BaseField::one().into_bigint();
for _ in 0..num_limb_in_a_group {
array.push(BaseField::from_bigint(cur).unwrap());
cur.muln(shift_per_limb as u32);
}
array
};
for (left_limb, right_limb) in left.iter().zip(right.iter()).rev() {
limb_pairs.push((left_limb.clone(), right_limb.clone()));
}
let mut groupped_limb_pairs = Vec::<(FpVar<BaseField>, FpVar<BaseField>, usize)>::new();
for limb_pairs_in_a_group in limb_pairs.chunks(num_limb_in_a_group) {
let mut left_total_limb = zero.clone();
let mut right_total_limb = zero.clone();
for ((left_limb, right_limb), shift) in
limb_pairs_in_a_group.iter().zip(shift_array.iter())
{
left_total_limb += &(left_limb * *shift);
right_total_limb += &(right_limb * *shift);
}
groupped_limb_pairs.push((
left_total_limb,
right_total_limb,
limb_pairs_in_a_group.len(),
));
}
let mut carry_in = zero;
let mut carry_in_value = BaseField::zero();
let mut accumulated_extra = BigUint::zero();
for (group_id, (left_total_limb, right_total_limb, num_limb_in_this_group)) in
groupped_limb_pairs.iter().enumerate()
{
let mut pad_limb_repr: <BaseField as PrimeField>::BigInt =
BaseField::one().into_bigint();
pad_limb_repr.muln(
(surfeit
+ (bits_per_limb - shift_per_limb)
+ shift_per_limb * num_limb_in_this_group
+ 1
+ 1) as u32,
);
let pad_limb = BaseField::from_bigint(pad_limb_repr).unwrap();
let left_total_limb_value = left_total_limb.value().unwrap_or_default();
let right_total_limb_value = right_total_limb.value().unwrap_or_default();
let mut carry_value =
left_total_limb_value + carry_in_value + pad_limb - right_total_limb_value;
let mut carry_repr = carry_value.into_bigint();
carry_repr.divn((shift_per_limb * num_limb_in_this_group) as u32);
carry_value = BaseField::from_bigint(carry_repr).unwrap();
let carry = FpVar::<BaseField>::new_witness(cs.clone(), || Ok(carry_value))?;
accumulated_extra += limbs_to_bigint(bits_per_limb, &[pad_limb]);
let (new_accumulated_extra, remainder) = accumulated_extra.div_rem(
&BigUint::from(2u64).pow((shift_per_limb * num_limb_in_this_group) as u32),
);
let remainder_limb = bigint_to_basefield::<BaseField>(&remainder);
let eqn_left = left_total_limb + pad_limb + &carry_in - right_total_limb;
let eqn_right = &carry
* BaseField::from(2u64).pow(&[(shift_per_limb * num_limb_in_this_group) as u64])
+ remainder_limb;
eqn_left.conditional_enforce_equal(&eqn_right, &Boolean::<BaseField>::TRUE)?;
accumulated_extra = new_accumulated_extra;
carry_in = carry.clone();
carry_in_value = carry_value;
if group_id == groupped_limb_pairs.len() - 1 {
carry.enforce_equal(&FpVar::<BaseField>::Constant(bigint_to_basefield(
&accumulated_extra,
)))?;
} else {
Reducer::<TargetField, BaseField>::limb_to_bits(&carry, surfeit + bits_per_limb)?;
}
}
Ok(())
}
}