use crate::{
params::{get_params, OptimizationType},
reduce::{bigint_to_basefield, limbs_to_bigint, Reducer},
AllocatedNonNativeFieldVar,
};
use snarkvm_fields::{FieldParameters, PrimeField};
use snarkvm_gadgets::{
fields::FpGadget,
traits::{
fields::FieldGadget,
utilities::{alloc::AllocGadget, boolean::Boolean},
},
};
use snarkvm_r1cs::{errors::SynthesisError, ConstraintSystem};
use num_bigint::BigUint;
use std::marker::PhantomData;
/// An unreduced non-native field element, typically the limb-wise product of two
/// [`AllocatedNonNativeFieldVar`]s, represented over the base field.
///
/// Values of this type are not yet reduced modulo the target field's modulus; call
/// `reduce` to obtain a regular [`AllocatedNonNativeFieldVar`].
#[derive(Debug)]
pub struct AllocatedNonNativeFieldMulResultVar<TargetField: PrimeField, BaseField: PrimeField> {
    /// Base-field limbs of the unreduced value, most-significant limb first.
    /// Constructors in this file produce `2 * num_limbs - 1` limbs (the width of a
    /// limb-wise product).
    pub limbs: Vec<FpGadget<BaseField>>,
    /// Bookkeeping factor that bounds how far the limbs may exceed their normal size;
    /// `reduce` feeds it to `overhead!` to size the quotient and the equality check.
    pub prod_of_num_of_additions: BaseField,
    /// Marker tying the value to its target field.
    #[doc(hidden)]
    pub target_phantom: PhantomData<TargetField>,
}
impl<TargetField: PrimeField, BaseField: PrimeField> AllocatedNonNativeFieldMulResultVar<TargetField, BaseField> {
    /// Lifts a non-native field gadget into the unreduced multiplication-result form.
    ///
    /// The source limbs are padded with zero limbs on the most-significant side and resized
    /// to `2 * num_limbs - 1` limbs, the width of a limb-wise product. The resulting
    /// `prod_of_num_of_additions` is the source's `num_of_additions_over_normal_form + 1`.
    ///
    /// # Errors
    ///
    /// Propagates any [`SynthesisError`] raised while creating the zero padding limb.
    pub fn from_allocated_nonnative_field_gadget<CS: ConstraintSystem<BaseField>>(
        cs: &mut CS,
        src: &AllocatedNonNativeFieldVar<TargetField, BaseField>,
    ) -> Result<Self, SynthesisError> {
        let field_parameters = get_params(
            TargetField::size_in_bits(),
            BaseField::size_in_bits(),
            src.get_optimization_type(),
        );

        // Limbs are stored most-significant first. Flip to least-significant first, resize
        // (zero limbs land on the high end), then flip back.
        let mut limbs = src.limbs.clone();
        limbs.reverse();
        limbs.resize(2 * field_parameters.num_limbs - 1, FpGadget::<BaseField>::zero(cs)?);
        limbs.reverse();

        Ok(Self {
            limbs,
            prod_of_num_of_additions: src.num_of_additions_over_normal_form + &BaseField::one(),
            target_phantom: PhantomData,
        })
    }

    /// Returns the target-field element represented by these limbs.
    ///
    /// The limbs are recombined into an integer and reduced modulo the target field's
    /// modulus. Limbs with no assigned value contribute zero.
    ///
    /// # Errors
    ///
    /// Propagates errors from computing the limb representation of the target modulus.
    pub fn value(&self) -> Result<TargetField, SynthesisError> {
        let params = get_params(
            TargetField::size_in_bits(),
            BaseField::size_in_bits(),
            self.get_optimization_type(),
        );
        let p_representations =
            AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations_from_big_integer(
                &<TargetField as PrimeField>::Parameters::MODULUS,
                self.get_optimization_type(),
            )?;
        let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations);

        let value_bigint = limbs_to_bigint(params.bits_per_limb, &self.limb_values_or_zero());
        Ok(bigint_to_basefield::<TargetField>(&(value_bigint % p_bigint)))
    }

    /// Reduces this unreduced value to an [`AllocatedNonNativeFieldVar`].
    ///
    /// The method witnesses the quotient `k = floor(self / p)` as
    /// `TargetField::size_in_bits() + surfeit` boolean bits. It also witnesses the remainder
    /// `r` from [`Self::value`]. It then enforces `self == k * p + r` limb-wise via
    /// `Reducer::group_and_check_equality`, and returns the `r` gadget.
    ///
    /// # Errors
    ///
    /// Propagates any [`SynthesisError`] from allocation or constraint generation.
    pub fn reduce<CS: ConstraintSystem<BaseField>>(
        &self,
        cs: &mut CS,
    ) -> Result<AllocatedNonNativeFieldVar<TargetField, BaseField>, SynthesisError> {
        // NOTE: the order in which variables and constraints are created below determines the
        // synthesized circuit; keep it (and every namespace string) unchanged.

        // `get_params` is a pure lookup; compute it once and reuse it throughout.
        let params = get_params(
            TargetField::size_in_bits(),
            BaseField::size_in_bits(),
            self.get_optimization_type(),
        );

        // Step 1: the target modulus p, both as constant limb gadgets and as an integer.
        let p_representations =
            AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations_from_big_integer(
                &<TargetField as PrimeField>::Parameters::MODULUS,
                self.get_optimization_type(),
            )?;
        let p_bigint = limbs_to_bigint(params.bits_per_limb, &p_representations);

        let mut p_limbs = Vec::with_capacity(p_representations.len());
        for (i, limb) in p_representations.iter().enumerate() {
            p_limbs.push(FpGadget::<BaseField>::alloc_constant(
                cs.ns(|| format!("alloc_constant_{}", i)),
                || Ok(limb),
            )?);
        }

        // Step 2: extra bits of headroom for the quotient and the grouped equality check.
        let surfeit = overhead!(self.prod_of_num_of_additions + &BaseField::one()) + 1 + 1;

        // Step 3: witness k = floor(self / p) one bit at a time, least-significant bit first.
        let k_bits = {
            let mut k_cur = limbs_to_bigint(params.bits_per_limb, &self.limb_values_or_zero()) / p_bigint;
            let total_length = TargetField::size_in_bits() + surfeit;
            let mut bits = Vec::with_capacity(total_length);
            for i in 0..total_length {
                bits.push(Boolean::alloc(cs.ns(|| format!("alloc_{}", i)), || {
                    Ok(&k_cur % 2u64 == BigUint::from(1u64))
                })?);
                k_cur /= 2u64;
            }
            bits
        };

        // Step 4: pack the bits of k into limbs. Every limb but the last takes `bits_per_limb`
        // bits; the last (most-significant) limb takes all remaining bits.
        let k_limbs = {
            let zero = FpGadget::Constant(BaseField::zero());
            let mut limbs = Vec::with_capacity(params.num_limbs);
            let mut offset = 0;
            for i in 0..params.num_limbs {
                let this_limb_size = if i != params.num_limbs - 1 {
                    params.bits_per_limb
                } else {
                    k_bits.len() - offset
                };
                // Borrow the bits in place instead of re-copying the tail on every limb.
                let this_limb_bits = &k_bits[offset..offset + this_limb_size];
                offset += this_limb_size;

                let mut limb = zero.clone();
                let mut cur = BaseField::one();
                for (j, bit) in this_limb_bits.iter().enumerate() {
                    let temp =
                        FpGadget::<BaseField>::from_boolean(cs.ns(|| format!("from_boolean_{}_{}", i, j)), *bit)?;
                    let temp = temp.mul_by_constant(cs.ns(|| format!("mul_by_constant_{}_{}", i, j)), &cur)?;
                    limb = limb.add(cs.ns(|| format!("add_{}_{}", i, j)), &temp)?;
                    cur.double_in_place();
                }
                limbs.push(limb);
            }
            // Bits were consumed least-significant first; store limbs most-significant first.
            limbs.reverse();
            limbs
        };

        // Step 5: witness the remainder r.
        let r_gadget = AllocatedNonNativeFieldVar::<TargetField, BaseField>::alloc(cs.ns(|| "r"), || self.value())?;

        // Step 6: schoolbook limb-wise product k * p (no carries), then add r at the
        // least-significant end.
        let zero = FpGadget::<BaseField>::zero(cs.ns(|| "zero"))?;
        let mut kp_plus_r_limbs = vec![zero; 2 * params.num_limbs - 1];
        for i in 0..params.num_limbs {
            for j in 0..params.num_limbs {
                let temp = p_limbs[i].mul(cs.ns(|| format!("mul_{}_{}", i, j)), &k_limbs[j])?;
                kp_plus_r_limbs[i + j] =
                    kp_plus_r_limbs[i + j].add(cs.ns(|| format!("add_temp_{}_{}", i, j)), &temp)?;
            }
        }

        let kp_plus_r_limbs_len = kp_plus_r_limbs.len();
        for (i, limb) in r_gadget.limbs.iter().rev().enumerate() {
            let idx = kp_plus_r_limbs_len - 1 - i;
            kp_plus_r_limbs[idx] = kp_plus_r_limbs[idx].add(cs.ns(|| format!("add_limb{}", i)), limb)?;
        }

        // Step 7: enforce self == k * p + r.
        Reducer::<TargetField, BaseField>::group_and_check_equality(
            &mut cs.ns(|| "group_and_check_equality"),
            surfeit,
            2 * params.bits_per_limb,
            params.bits_per_limb,
            &self.limbs,
            &kp_plus_r_limbs,
        )?;

        Ok(r_gadget)
    }

    /// Adds two unreduced values limb by limb without reducing.
    ///
    /// The resulting `prod_of_num_of_additions` is the sum of both operands' factors.
    ///
    /// # Panics
    ///
    /// Panics if the two operands use different optimization types.
    ///
    /// # Errors
    ///
    /// Propagates any [`SynthesisError`] from the limb additions.
    pub fn add<CS: ConstraintSystem<BaseField>>(&self, cs: &mut CS, other: &Self) -> Result<Self, SynthesisError> {
        assert_eq!(self.get_optimization_type(), other.get_optimization_type());

        let limbs = self
            .limbs
            .iter()
            .zip(other.limbs.iter())
            .enumerate()
            .map(|(i, (l1, l2))| l1.add(cs.ns(|| format!("add_{}", i)), l2))
            .collect::<Result<Vec<_>, _>>()?;

        Ok(Self {
            limbs,
            prod_of_num_of_additions: self.prod_of_num_of_additions + &other.prod_of_num_of_additions,
            target_phantom: PhantomData,
        })
    }

    /// Adds a target-field constant to this unreduced value without reducing.
    ///
    /// The constant's limbs are added to the least-significant limbs of `self`. Any higher
    /// limbs are copied unchanged. `prod_of_num_of_additions` grows by one.
    ///
    /// # Errors
    ///
    /// Propagates errors from computing the constant's limb representation or from the limb
    /// additions.
    pub fn add_constant<CS: ConstraintSystem<BaseField>>(
        &self,
        cs: &mut CS,
        other: &TargetField,
    ) -> Result<Self, SynthesisError> {
        let mut other_limbs = AllocatedNonNativeFieldVar::<TargetField, BaseField>::get_limbs_representations(
            other,
            self.get_optimization_type(),
        )?;
        // Align both operands at their least-significant limb.
        other_limbs.reverse();

        let mut new_limbs = Vec::with_capacity(self.limbs.len());
        for (i, limb) in self.limbs.iter().rev().enumerate() {
            match other_limbs.get(i) {
                Some(constant) => {
                    new_limbs.push(limb.add_constant(cs.ns(|| format!("add_constant_{}", i)), constant)?)
                }
                None => new_limbs.push(limb.clone()),
            }
        }
        new_limbs.reverse();

        Ok(Self {
            limbs: new_limbs,
            prod_of_num_of_additions: self.prod_of_num_of_additions + &BaseField::one(),
            target_phantom: PhantomData,
        })
    }

    /// Returns the limb-layout optimization used by this gadget (always `Weight`).
    pub(crate) fn get_optimization_type(&self) -> OptimizationType {
        OptimizationType::Weight
    }

    /// Collects each limb's assigned value, substituting zero where no value is assigned.
    fn limb_values_or_zero(&self) -> Vec<BaseField> {
        self.limbs.iter().map(|limb| limb.get_value().unwrap_or_default()).collect()
    }
}