use std::{cmp::max, iter};
use halo2_base::{
gates::{GateInstructions, RangeInstructions},
utils::{decompose_bigint, BigPrimeField},
AssignedValue, Context,
QuantumCell::{Constant, Existing, Witness},
};
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::{One, Signed};
use super::{check_carry_to_zero, CRTInteger, OverflowInteger, ProperCrtUint, ProperUint};
pub fn crt<F: BigPrimeField>(
range: &impl RangeInstructions<F>,
ctx: &mut Context<F>,
a: CRTInteger<F>,
k_bits: usize, modulus: &BigInt,
mod_vec: &[F],
mod_native: F,
limb_bits: usize,
limb_bases: &[F],
limb_base_big: &BigInt,
) -> ProperCrtUint<F> {
let n = limb_bits;
let k = a.truncation.limbs.len();
let trunc_len = n * k;
debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2);
let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize);
debug_assert!(quot_max_bits < trunc_len);
let quot_last_limb_bits = quot_max_bits - n * (k - 1);
let out_max_bits = modulus.bits() as usize;
let out_last_limb_bits = out_max_bits - n * (k - 1);
let (quot_val, out_val) = a.value.div_mod_floor(modulus);
debug_assert!(out_val < (BigInt::one() << (n * k)));
debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits));
let out_vec = decompose_bigint::<F>(&out_val, k, n);
let quot_vec = decompose_bigint::<F>("_val, k, n);
assert_eq!(mod_vec.len(), k);
let mut quot_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
let mut out_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
let mut check_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
for (i, ((a_limb, quot_v), out_v)) in
a.truncation.limbs.into_iter().zip(quot_vec).zip(out_vec).enumerate()
{
let (prod, new_quot_cell) = range.gate().inner_product_left_last(
ctx,
quot_assigned.iter().map(|a| Existing(*a)).chain(iter::once(Witness(quot_v))),
mod_vec[..=i].iter().rev().map(|c| Constant(*c)),
);
let temp1 = *prod.value() - a_limb.value();
let check_val = temp1 + out_v;
ctx.assign_region(
[
Constant(-F::ONE),
Existing(a_limb),
Witness(temp1),
Constant(F::ONE),
Witness(out_v),
Witness(check_val),
],
[-1, 2], );
let check_cell = ctx.last().unwrap();
let out_cell = ctx.get(-2);
quot_assigned.push(new_quot_cell);
out_assigned.push(out_cell);
check_assigned.push(check_cell);
}
for (out_index, out_cell) in out_assigned.iter().enumerate() {
let limb_bits = if out_index == k - 1 { out_last_limb_bits } else { n };
range.range_check(ctx, *out_cell, limb_bits);
}
for (q_index, quot_cell) in quot_assigned.iter().enumerate() {
let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n };
let limb_base =
if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] };
let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base));
range.range_check(ctx, quot_shift, limb_bits + 1);
}
let check_overflow_int = OverflowInteger::new(
check_assigned,
max(max(limb_bits, a.truncation.max_limb_bits) + 1, 2 * n + k_bits),
);
check_carry_to_zero::truncate::<F>(
range,
ctx,
check_overflow_int,
limb_bits,
limb_bases[1],
limb_base_big,
);
let quot_native =
OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases);
let out_native =
OverflowInteger::evaluate_native(ctx, range.gate(), out_assigned.clone(), limb_bases);
ctx.assign_region(
[Constant(mod_native), Existing(quot_native), Existing(a.native)],
[-1], );
ProperCrtUint(CRTInteger::new(
ProperUint(out_assigned).into_overflow(limb_bits),
out_native,
out_val,
))
}