use std::ops::{Mul, Rem};
use midnight_proofs::{
circuit::Value,
plonk::{Advice, Column, Expression, VirtualCells},
poly::Rotation,
};
use num_bigint::{BigInt as BI, BigUint, ToBigInt};
use num_integer::Integer;
use num_traits::{One, Signed, Zero};
use crate::{
field::decomposition::cpu_utils::compute_optimal_limb_sizes, utils::util::bigint_to_fe,
CircuitField,
};
/// Returns the non-negative remainder of `value` divided by `modulus`.
///
/// `BigInt` remainder truncates toward zero, so the raw remainder carries the sign
/// of `value`. A negative remainder is lifted by adding `modulus` once.
/// The result lies in `[0, modulus)` only for a positive `modulus`. A negative
/// `modulus` is not handled specially; presumably callers never pass one.
pub fn urem(value: &BI, modulus: &BI) -> BI {
    let remainder = value.rem(modulus);
    if remainder.is_negative() {
        remainder + modulus
    } else {
        remainder
    }
}
/// Reduces `value` modulo `modulus` to the representative centred around zero.
///
/// The canonical residue `r` in `[0, modulus)` is computed by [`urem`]. It is
/// kept as-is when `2 * r <= modulus` and replaced by `r - modulus` otherwise.
/// For a positive `modulus` the result therefore lies in
/// `(-modulus / 2, modulus / 2]`.
pub fn signed_mod(value: &BI, modulus: &BI) -> BI {
    let residue = urem(value, modulus);
    let in_lower_half = &residue * 2 <= *modulus;
    if in_lower_half {
        residue
    } else {
        residue - modulus
    }
}
/// Builds a closure that maps an integer to its centred representative modulo the
/// modulus of field `K` (see [`signed_mod`]).
///
/// The modulus is converted to a `BigInt` once, and the closure owns that value.
pub fn signed_repr<K: CircuitField>() -> impl Fn(BI) -> BI {
    let field_modulus: BI = K::modulus().to_bigint().unwrap();
    move |input: BI| signed_mod(&input, &field_modulus)
}
/// Returns `ceil(log2(value))`, computed as the bit length of `value - 1`.
///
/// Expects `value >= 1`. For example, `1 -> 0`, `2 -> 1`, `3 -> 2`, `4 -> 2`, `5 -> 3`.
/// NOTE(review): for `value <= 0` the result is not a meaningful logarithm.
/// `BigInt::bits` ignores the sign, so `ceil_log2(0)` yields 1.
pub fn ceil_log2(value: &BI) -> u32 {
    let predecessor = value - BI::one();
    let bit_length = predecessor.bits();
    bit_length as u32
}
/// Returns `2^e` for the exponent `e` in `ceil_log2(x) ..= ceil_log2(x) + 128` whose
/// range check is cheapest.
///
/// The cost of an exponent is the length of the result of `compute_optimal_limb_sizes`
/// for that bit bound. Exponents are tried in increasing order, and ties keep the
/// smallest one. The result is therefore the smallest power of two `>= x` unless a
/// larger power is strictly cheaper to range-check.
fn next_cheapest_power_of_two(nb_parallel_range_checks: usize, max_bit_len: u32, x: &BI) -> BI {
    // One shared table is handed to every `compute_optimal_limb_sizes` call below.
    // It is presumably a cache of solutions keyed by bit bound; verify in cpu_utils.
    let mut solutions: std::collections::HashMap<i32, Vec<Vec<usize>>> =
        std::collections::HashMap::new();
    let smallest_exponent = ceil_log2(x);
    // `min_by_key` returns the first minimum, matching a strict `<` scan.
    let cheapest_exponent = (smallest_exponent..=smallest_exponent + 128)
        .min_by_key(|&exponent| {
            compute_optimal_limb_sizes(
                &mut solutions,
                nb_parallel_range_checks,
                max_bit_len as usize,
                exponent as i32,
            )
            .len()
        })
        .expect("the candidate exponent range is never empty");
    BI::pow(&BI::from(2), cheapest_exponent)
}
pub fn bi_to_limbs(nb_limbs: u32, base: &BI, value: &BI) -> Vec<BI> {
if value.is_negative() {
panic!("bi_to_limbs: value must be greater than or equal to zero");
}
let mut output = vec![];
let mut q = (*value).clone();
let mut r;
while output.len() < nb_limbs as usize {
(q, r) = q.div_rem(base);
output.push(r.clone());
}
if !BI::is_zero(&q) {
panic!(
"bi_to_limbs: {} cannot be expressed in base {} with {} limbs",
value, base, nb_limbs
)
};
output
}
/// Recombines `limbs` (least-significant first) in `base`.
///
/// Returns `sum_i limbs[i] * base^i`, evaluated with Horner's rule from the most
/// significant limb down. An empty slice yields zero.
pub fn bi_from_limbs(base: &BI, limbs: &[BI]) -> BI {
    let mut acc = BI::zero();
    for limb in limbs.iter().rev() {
        acc = acc * base + limb;
    }
    acc
}
pub fn big_to_limbs(nb_limbs: u32, base: &BigUint, value: &BigUint) -> Vec<BigUint> {
let mut output = vec![];
let mut q = (*value).clone();
let mut r;
while output.len() < nb_limbs as usize {
(q, r) = q.div_rem(base);
output.push(r.clone());
}
if !BigUint::is_zero(&q) {
panic!(
"big_to_limbs: {} cannot be expressed in base {} with {} limbs",
value, base, nb_limbs
)
};
output
}
/// Recombines unsigned `limbs` (least-significant first) in `base`.
///
/// Returns `sum_i limbs[i] * base^i` via Horner's rule. An empty slice yields zero.
pub fn big_from_limbs(base: &BigUint, limbs: &[BigUint]) -> BigUint {
    let mut acc = BigUint::zero();
    for limb in limbs.iter().rev() {
        acc = acc * base + limb;
    }
    acc
}
pub fn sum_bigints(coeffs: &[BI], values: &[BI]) -> BI {
debug_assert!(coeffs.len() == values.len());
values.iter().zip(coeffs.iter()).map(|(v, b)| b * v).sum::<BI>()
}
/// Builds the circuit expression `sum_i coeffs[i] * exprs[i]`.
///
/// Each coefficient is converted to a constant of `F` with `bigint_to_fe`. Terms are
/// added left to right onto an initial `Expression::from(0)`.
/// The slices are expected to have equal length (checked only in debug builds).
pub fn sum_exprs<F: CircuitField>(coeffs: &[BI], exprs: &[Expression<F>]) -> Expression<F> {
    debug_assert!(coeffs.len() == exprs.len());
    let mut total = Expression::from(0);
    for (expr, coeff) in exprs.iter().zip(coeffs) {
        let scaled = Expression::Constant(bigint_to_fe::<F>(coeff)) * expr.clone();
        total = total + scaled;
    }
    total
}
/// Returns every product `v[i] * w[j]`, in row-major order.
///
/// The product for `(i, j)` is at index `i * w.len() + j`. The result is empty
/// if either input is empty.
pub fn pair_wise_prod<T: Mul<Output = T> + Clone>(v: &[T], w: &[T]) -> Vec<T> {
    // Flatten lazily; collecting each row into its own Vec was a needless allocation.
    v.iter()
        .flat_map(|vi| w.iter().map(move |wj| vi.clone() * wj.clone()))
        .collect()
}
/// Queries each advice column in `columns` at the same `rotation`.
///
/// Returns the resulting expressions in column order.
pub fn get_advice_vec<F: CircuitField>(
    meta: &mut VirtualCells<'_, F>,
    columns: &[Column<Advice>],
    rotation: Rotation,
) -> Vec<Expression<F>> {
    let mut queried = Vec::with_capacity(columns.len());
    for column in columns {
        queried.push(meta.query_advice(*column, rotation));
    }
    queried
}
/// Computes the witness bounds used to enforce `expr ≡ 0 (mod m)` in a circuit over `F`.
///
/// Here `m` is the modulus of field `K`.
///
/// NOTE(review): the interpretation below is inferred from the arithmetic here and in
/// `compute_u` / `compute_vj`; confirm it against the gadget consuming these bounds.
///
/// The identity is witnessed as `expr = (u + k_min) * m`, with `u` in `[0, u_max)`.
/// `k_min` and `k_max` are the extreme integer quotients allowed by `expr_bounds`.
/// It is checked modulo the native modulus and modulo a prefix of `moduli`.
/// Moduli are taken in order until the lcm of the native modulus and the selected
/// moduli strictly exceeds both `-lower_bound` and `upper_bound` of
/// `expr - (u + k_min) * m`. Past that point, vanishing modulo all of them implies
/// vanishing over the integers.
///
/// For each selected `mj`, the matching entry of `expr_mj_bounds` bounds `expr_mj`.
/// `expr_mj` is presumably `expr` computed with inputs reduced modulo `mj`.
/// The per-modulus identity is
/// `expr_mj - u * (m mod mj) - (k_min * m mod mj) = (vj + lj_min) * mj`, with `vj`
/// in `[0, vj_max)`.
///
/// Returns `((k_min, u_max), v_bounds)`, where `v_bounds[j] = (lj_min, vj_max)` for
/// the j-th selected modulus. `u_max` and each `vj_max` are powers of two chosen by
/// `next_cheapest_power_of_two`.
///
/// # Panics
///
/// - if the lcm threshold is not reached after using every modulus in `moduli`;
/// - if `expr_mj_bounds` has fewer entries than the number of selected moduli;
/// - if any per-modulus identity may wrap around the native modulus.
pub fn get_identity_auxiliary_bounds<F, K>(
    equation_name: &str,
    moduli: &[BI],
    expr_bounds: (BI, BI),
    expr_mj_bounds: &[(BI, BI)],
    nb_parallel_range_checks: usize,
    max_bit_len: u32,
) -> ((BI, BI), Vec<(BI, BI)>)
where
    F: CircuitField,
    K: CircuitField,
{
    let m = &K::modulus().to_bigint().unwrap();
    let native_modulus = &F::modulus().to_bigint().unwrap();

    // Integer quotients `k` with `k * m` inside `expr_bounds`.
    let k_min = expr_bounds.0.div_ceil(m);
    let k_max = expr_bounds.1.div_floor(m);
    // `u = k - k_min` must cover `k_max - k_min + 1` values.
    let u_max = next_cheapest_power_of_two(
        nb_parallel_range_checks,
        max_bit_len,
        &(&k_max - &k_min + BI::one()),
    );

    // Conservative bounds of `expr - (u + k_min) * m` over `u` in `[0, u_max]`.
    let lower_bound = expr_bounds.0 - (&u_max + &k_min) * m;
    let upper_bound = expr_bounds.1 - &k_min * m;

    // Accumulate auxiliary moduli until the lcm dominates both bounds.
    let mut necessary_moduli = vec![];
    let mut lcm = native_modulus.clone();
    for mj in moduli.iter() {
        if lcm > -&lower_bound && lcm > upper_bound {
            break;
        }
        lcm = lcm.lcm(mj);
        necessary_moduli.push(mj.clone());
    }
    if lcm <= -lower_bound || lcm <= upper_bound {
        panic!("lcm-threshold not reached, consider using extra auxiliari moduli")
    }

    // Each selected modulus needs its own `expr_mj` bounds. Without this check the
    // `zip` below would silently drop trailing moduli and return too few `v_bounds`.
    assert!(
        expr_mj_bounds.len() >= necessary_moduli.len(),
        "Equation {} needs {} auxiliary moduli but only {} expr_mj bounds were provided",
        equation_name,
        necessary_moduli.len(),
        expr_mj_bounds.len()
    );

    let v_bounds: Vec<_> = necessary_moduli
        .iter()
        .zip(expr_mj_bounds.iter())
        .map(|(mj, (expr_mj_min, expr_mj_max))| {
            let m_mod_mj = urem(m, mj);
            let k_min_m_mod_mj = urem(&(&k_min * m), mj);
            // Range of the integer quotient `l = vj + lj_min` of the per-modulus identity.
            let lj_min = (expr_mj_min - &u_max * &m_mod_mj - &k_min_m_mod_mj).div_ceil(mj);
            let lj_max = (expr_mj_max - &k_min_m_mod_mj).div_floor(mj);
            let vj_max = next_cheapest_power_of_two(
                nb_parallel_range_checks,
                max_bit_len,
                &(&lj_max - &lj_min + BI::one()),
            );
            // The per-modulus identity is enforced over `F`, so its integer value
            // must stay strictly within the native modulus in absolute value.
            let lower_bound_j = expr_mj_min
                - &u_max * &m_mod_mj
                - &k_min_m_mod_mj
                - (&vj_max + &lj_min) * mj;
            let upper_bound_j = expr_mj_max - &k_min_m_mod_mj - &lj_min * mj;
            if *native_modulus <= -lower_bound_j || *native_modulus <= upper_bound_j {
                panic!(
                    "Equation {} modulo {} may wrap-around the native modulus",
                    equation_name, mj
                )
            }
            (lj_min, vj_max)
        })
        .collect();
    ((k_min, u_max), v_bounds)
}
/// Computes the witness `u = expr / m - k_min` for the identity `expr = (u + k_min) * m`.
///
/// `u_bounds` is `(k_min, u_max)`. The quotient uses truncating `div_rem`.
/// Outside `cfg(test)`, when `_assertions` holds a known `true`, debug builds check
/// that the division is exact and that `0 <= u < u_max`.
pub fn compute_u(m: &BI, expr: &BI, u_bounds: (&BI, &BI), _assertions: Value<bool>) -> BI {
    let (k_min, _u_max) = u_bounds;
    let (quotient, _r) = expr.div_rem(m);
    let u = &quotient - k_min;
    #[cfg(not(test))]
    _assertions.map(|enabled| {
        if enabled {
            debug_assert!(BI::is_zero(&_r));
            debug_assert!(!BI::is_negative(&u));
            debug_assert!(&u < _u_max);
        }
    });
    u
}
/// Computes the witness `vj` for the identity modulo the auxiliary modulus `mj`.
///
/// The identity is `expr_mj - u * (m mod mj) - (k_min * m mod mj) = (vj + lj_min) * mj`.
///
/// `vj_bounds` is `(lj_min, vj_max)`. The quotient uses truncating `div_rem`.
/// Outside `cfg(test)`, when `_assertions` holds a known `true`, debug builds check
/// that the division is exact and that `0 <= vj < vj_max`.
pub fn compute_vj(
    m: &BI,
    mj: &BI,
    expr_mj: &BI,
    u: &BI,
    k_min: &BI,
    vj_bounds: (&BI, &BI),
    _assertions: Value<bool>,
) -> BI {
    let (lj_min, _vj_max) = vj_bounds;
    let shifted = expr_mj - u * urem(m, mj) - urem(&(k_min * m), mj);
    let (quotient, _r) = shifted.div_rem(mj);
    let vj = &quotient - lj_min;
    #[cfg(not(test))]
    _assertions.map(|enabled| {
        if enabled {
            debug_assert!(BI::is_zero(&_r));
            debug_assert!(!BI::is_negative(&vj));
            debug_assert!(&vj < _vj_max);
        }
    });
    vj
}