use std::fmt::Debug;
use num::BigUint;
use slop_air::AirBuilder;
use slop_algebra::PrimeField32;
use sp1_core_executor::events::ByteRecord;
use sp1_curves::params::{FieldParameters, Limbs};
use sp1_derive::AlignedBorrow;
use sp1_hypercube::air::SP1AirBuilder;
use sp1_primitives::polynomial::Polynomial;
use super::util_air::eval_field_operation;
use crate::air::WordAirBuilder;
/// Columns for the field "denominator" operation: a modular division of `a` by `1 + b` or
/// `1 - b` (selected by the `sign` argument of `populate` / `eval`) over the modulus given by
/// `P`.
///
/// The layout is `#[repr(C)]`, so the field order below is significant.
#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct FieldDenCols<T, P: FieldParameters> {
    /// Limbs of the quotient computed and returned by `populate`. Range checked as `u8` values.
    pub result: Limbs<T, P::Limbs>,
    /// Limbs of the carry, i.e. the multiple of the modulus by which the two sides of the
    /// division equation differ over the integers. Range checked as `u8` values.
    pub(crate) carry: Limbs<T, P::Limbs>,
    /// Limbs of the witness polynomial handed to `eval_field_operation`. `populate` stores them
    /// shifted by `P::WITNESS_OFFSET`; they are range checked as `u16` values.
    pub(crate) witness: Limbs<T, P::Witness>,
}
impl<F: PrimeField32, P: FieldParameters> FieldDenCols<F, P> {
    /// Populates the columns for a modular division and returns the quotient.
    ///
    /// With `p = P::modulus()`, the value computed is
    /// * `a / (1 + b) mod p` when `sign` is `true`, and
    /// * `a / (1 - b) mod p` when `sign` is `false`.
    ///
    /// Besides `result`, the method fills `carry` (the integer multiple of `p` separating both
    /// sides of the division equation) and `witness` (the quotient of the vanishing polynomial
    /// by `x - 256`, shifted by `P::WITNESS_OFFSET`), and registers the matching `u8` / `u16`
    /// range checks on `record`.
    ///
    /// # Panics
    ///
    /// * If `b > p`, because `p - b` is computed on unsigned big integers.
    /// * In debug builds, if any of the internal consistency assertions fails, e.g. when the
    ///   denominator has no inverse modulo `p` or a witness limb does not fit in a `u16`.
    pub fn populate(
        &mut self,
        record: &mut impl ByteRecord,
        a: &BigUint,
        b: &BigUint,
        sign: bool,
    ) -> BigUint {
        let p = P::modulus();
        let minus_b_int = &p - b;
        let b_signed = if sign { b.clone() } else { minus_b_int };
        // The remainder operator takes the modulus by reference; no copy of `p` is required.
        let denominator = (b_signed + 1u32) % &p;
        // Inverse through exponentiation by `p - 2` (assumes `p` is prime); the debug assertion
        // below verifies that the product with the denominator is really 1.
        let den_inv = denominator.modpow(&(&p - 2u32), &p);
        let result = (a * &den_inv) % &p;
        debug_assert_eq!(&den_inv * &denominator % &p, BigUint::from(1u32));
        debug_assert!(result < p);

        // Integer form of the relation, arranged so that `lhs - rhs` is a non-negative multiple
        // of `p`:
        //   sign == true : b * result + result = a      + carry * p
        //   sign == false: b * result + a      = result + carry * p
        let equation_lhs = if sign { b * &result + &result } else { b * &result + a };
        let equation_rhs = if sign { a.clone() } else { result.clone() };
        let carry = (&equation_lhs - &equation_rhs) / &p;
        debug_assert!(carry < p);
        debug_assert_eq!(&carry * &p, &equation_lhs - &equation_rhs);

        // Little-endian byte limbs, padded (or cut) to the requested number of limbs.
        let to_limbs = |value: &BigUint, len: usize| {
            let mut limbs = value.to_bytes_le();
            limbs.resize(len, 0);
            limbs
        };
        let p_a = to_limbs(a, P::NB_LIMBS);
        let p_b = to_limbs(b, P::NB_LIMBS);
        let p_p = to_limbs(&p, P::MODULUS_LIMBS);
        let p_result = to_limbs(&result, P::NB_LIMBS);
        let p_carry = to_limbs(&carry, P::NB_LIMBS);

        // Coefficients of the vanishing polynomial `lhs - rhs - carry * p`, built limb by limb
        // with signed arithmetic because individual coefficients may be negative.
        let len = P::NB_WITNESS_LIMBS + 1;
        let mut p_vanishing_limbs = vec![0i32; len];
        for i in 0..P::NB_LIMBS {
            for j in 0..P::NB_LIMBS {
                // A product of two bytes is at most 255 * 255 and therefore fits in a `u16`.
                p_vanishing_limbs[i + j] += (p_b[i] as u16 * p_result[j] as u16) as i32;
            }
        }
        for i in 0..P::NB_LIMBS {
            for j in 0..P::MODULUS_LIMBS {
                p_vanishing_limbs[i + j] -= (p_carry[i] as u16 * p_p[j] as u16) as i32;
            }
        }
        if sign {
            for i in 0..P::NB_LIMBS {
                p_vanishing_limbs[i] += p_result[i] as i32;
                p_vanishing_limbs[i] -= p_a[i] as i32;
            }
        } else {
            for i in 0..P::NB_LIMBS {
                p_vanishing_limbs[i] -= p_result[i] as i32;
                p_vanishing_limbs[i] += p_a[i] as i32;
            }
        }

        // Synthetic division by `x - 256`, performed in place from the top coefficient down.
        // Afterwards the first `len - 1` entries hold the quotient and `pol_carry` holds the
        // remainder, which must be zero because the polynomial vanishes at `x = 256`.
        let mut pol_carry = p_vanishing_limbs[len - 1];
        for i in (0..len - 1).rev() {
            let ai = p_vanishing_limbs[i];
            p_vanishing_limbs[i] = pol_carry;
            pol_carry = ai + pol_carry * 256;
        }
        debug_assert_eq!(pol_carry, 0);

        for i in 0..P::NB_LIMBS {
            self.result[i] = F::from_canonical_u8(p_result[i]);
            self.carry[i] = F::from_canonical_u8(p_carry[i]);
        }
        for i in 0..P::NB_WITNESS_LIMBS {
            // The offset moves the signed quotient limb into the unsigned range that is range
            // checked below; the narrowing cast would silently wrap if that ever failed.
            let shifted = p_vanishing_limbs[i] + P::WITNESS_OFFSET as i32;
            debug_assert!(
                (0..=i32::from(u16::MAX)).contains(&shifted),
                "shifted witness limb {i} does not fit in a u16: {shifted}"
            );
            self.witness[i] = F::from_canonical_u16(shifted as u16);
        }

        record.add_u8_range_checks_field(&self.result.0);
        record.add_u8_range_checks_field(&self.carry.0);
        record.add_u16_range_checks_field(&self.witness.0);

        result
    }
}
impl<V: Copy, P: FieldParameters> FieldDenCols<V, P>
where
    Limbs<V, P::Limbs>: Copy,
{
    /// Constrains the columns to encode a division of `a` by `1 + b` (`sign == true`) or by
    /// `1 - b` (`sign == false`) modulo the modulus of `P`.
    ///
    /// The relation is expressed on limb polynomials as
    /// * `b * result + result - a - carry * modulus` when `sign` is `true`, and
    /// * `b * result + a - result - carry * modulus` when `sign` is `false`,
    ///
    /// and that polynomial is passed, together with the witness columns, to
    /// `eval_field_operation`. The `result` and `carry` limbs are additionally range checked
    /// as `u8` values and the `witness` limbs as `u16` values, each gated by `is_real`.
    #[allow(clippy::too_many_arguments)]
    pub fn eval<AB: SP1AirBuilder<Var = V>>(
        &self,
        builder: &mut AB,
        a: &Limbs<AB::Var, P::Limbs>,
        b: &Limbs<AB::Var, P::Limbs>,
        sign: bool,
        is_real: impl Into<AB::Expr> + Clone,
    ) where
        V: Into<AB::Expr>,
    {
        // Lift the operands and the stored columns to polynomials over expressions.
        let numerator: Polynomial<AB::Expr> = (*a).into();
        let den_operand: Polynomial<AB::Expr> = (*b).into();
        let quotient: Polynomial<AB::Expr> = self.result.into();
        let carry: Polynomial<AB::Expr> = self.carry.into();

        // Both variants share the `b * result` product and differ only in which of `a` and
        // `result` ends up on which side of the equation.
        let product = &den_operand * &quotient;
        let (lhs, rhs) = if sign {
            (product + &quotient, numerator)
        } else {
            (product + &numerator, quotient)
        };
        let difference = &lhs - &rhs;

        // Remove the multiple of the modulus recorded in `carry`.
        let modulus: Polynomial<AB::Expr> =
            Polynomial::from_iter(P::modulus_field_iter::<AB::F>().map(AB::Expr::from));
        let vanishing: Polynomial<AB::Expr> = difference - &carry * &modulus;

        let witness = self.witness.0.iter().into();
        eval_field_operation::<AB, P>(builder, &vanishing, &witness);

        // Range checks for every limb column owned by this gadget.
        builder.slice_range_check_u8(&self.result.0, is_real.clone());
        builder.slice_range_check_u8(&self.carry.0, is_real.clone());
        builder.slice_range_check_u16(&self.witness.0, is_real.clone());
    }
}