use std::fmt::Debug;
use num::{BigUint, Zero};
use p3_air::AirBuilder;
use p3_field::PrimeField32;
use sp1_derive::AlignedBorrow;
use super::params::{FieldParameters, Limbs};
use super::util::{compute_root_quotient_and_shift, split_u16_limbs_to_u8_limbs};
use super::util_air::eval_field_operation;
use crate::air::Polynomial;
use crate::air::SP1AirBuilder;
use crate::bytes::event::ByteRecord;
use typenum::Unsigned;
/// The modular field operations that a [`FieldOpCols`] row can witness.
///
/// `Add` and `Mul` are witnessed directly. `Sub` and `Div` are witnessed
/// through the inverse relation (`result + b = a` and `result * b = a`
/// respectively), so every variant ends up as an addition or a multiplication.
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
pub enum FieldOperation {
    /// Modular addition.
    Add,
    /// Modular multiplication.
    Mul,
    /// Modular subtraction.
    Sub,
    /// Modular division.
    Div,
}
/// Columns holding the output and the auxiliary witness data of one modular
/// field operation over the field described by `P`.
///
/// The struct is `#[repr(C)]` and derives `AlignedBorrow`, so the field order
/// below is part of the column layout and must not be rearranged.
#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct FieldOpCols<T, P: FieldParameters> {
/// Limbs of the value produced by the operation.
pub result: Limbs<T, P::Limbs>,
/// Limbs of the carry, i.e. the multiple of the modulus that is removed by
/// the modular reduction.
pub(crate) carry: Limbs<T, P::Limbs>,
/// First half of the witness limbs returned by `split_u16_limbs_to_u8_limbs`
/// (presumably the low bytes of each witness limb — confirm in `util`).
pub(crate) witness_low: Limbs<T, P::Witness>,
/// Second half of the witness limbs returned by `split_u16_limbs_to_u8_limbs`
/// (presumably the high bytes of each witness limb — confirm in `util`).
pub(crate) witness_high: Limbs<T, P::Witness>,
}
impl<F: PrimeField32, P: FieldParameters> FieldOpCols<F, P> {
/// Populates `result`, `carry`, `witness_low` and `witness_high` for
/// `a op b (mod modulus)` and returns the reduced result.
///
/// Only [`FieldOperation::Add`] and [`FieldOperation::Mul`] are accepted here;
/// `populate_with_modulus` rewrites `Sub` and `Div` into one of those two
/// before delegating to this method.
///
/// # Panics
///
/// Panics if `op` is `Sub` or `Div`, or if `modulus` is zero.
pub fn populate_carry_and_witness(
    &mut self,
    a: &BigUint,
    b: &BigUint,
    op: FieldOperation,
    modulus: &BigUint,
) -> BigUint {
    let p_a: Polynomial<F> = P::to_limbs_field::<F, _>(a).into();
    let p_b: Polynomial<F> = P::to_limbs_field::<F, _>(b).into();

    // Evaluate the operation exactly once, both over the integers and over
    // the limb polynomials, instead of recomputing `a op b` for every use.
    let (op_value, p_op) = match op {
        FieldOperation::Add => (a + b, &p_a + &p_b),
        FieldOperation::Mul => (a * b, &p_a * &p_b),
        FieldOperation::Sub | FieldOperation::Div => unreachable!(),
    };

    // op_value = carry * modulus + result with 0 <= result < modulus, so the
    // carry is simply the integer quotient.
    let result = &op_value % modulus;
    let carry = &op_value / modulus;
    debug_assert!(&result < modulus);
    debug_assert!(&carry < modulus);
    debug_assert_eq!(&carry * modulus, &op_value - &result);

    // The modulus is passed in at runtime, so its limbs are taken from its
    // little-endian byte representation.
    let p_modulus_limbs = modulus
        .to_bytes_le()
        .into_iter()
        .map(F::from_canonical_u8)
        .collect::<Vec<F>>();
    let p_modulus: Polynomial<F> = p_modulus_limbs.iter().into();
    let p_result: Polynomial<F> = P::to_limbs_field::<F, _>(&result).into();
    let p_carry: Polynomial<F> = P::to_limbs_field::<F, _>(&carry).into();

    // Polynomial form of `a op b - result - carry * modulus`; the witness is
    // derived from it by `compute_root_quotient_and_shift`.
    let p_vanishing: Polynomial<F> = &p_op - &p_result - &p_carry * &p_modulus;
    let p_witness = compute_root_quotient_and_shift(
        &p_vanishing,
        P::WITNESS_OFFSET,
        P::NB_BITS_PER_LIMB as u32,
        P::NB_WITNESS_LIMBS,
    );
    let (mut p_witness_low, mut p_witness_high) = split_u16_limbs_to_u8_limbs(&p_witness);

    self.result = p_result.into();
    self.carry = p_carry.into();

    // Resize both halves to the exact column count before the fixed-size
    // conversion into `Limbs`.
    p_witness_low.resize(P::Witness::USIZE, F::zero());
    p_witness_high.resize(P::Witness::USIZE, F::zero());
    self.witness_low = Limbs(p_witness_low.try_into().unwrap());
    self.witness_high = Limbs(p_witness_high.try_into().unwrap());

    result
}
/// Populates all columns for `a op b (mod modulus)`, records the u8 range
/// checks for every populated limb in `record`, and returns the result.
///
/// `Sub` and `Div` are witnessed through the inverse relation: the carry and
/// witness columns are generated for `result + b` (resp. `result * b`), while
/// the `result` column still holds the actual difference (resp. quotient).
///
/// # Panics
///
/// - If `op` is `Div`, `b` is zero and `a` is non-zero.
/// - If `op` is `Sub` and `b > modulus + a` (unsigned underflow).
/// - If `op` is `Div` and `modulus < 2` (unsigned underflow of `modulus - 2`).
// The argument list mirrors the byte-lookup API (record, shard, channel) plus
// the operands, so it cannot be shortened without changing the interface.
#[allow(clippy::too_many_arguments)]
pub fn populate_with_modulus(
    &mut self,
    record: &mut impl ByteRecord,
    shard: u32,
    channel: u8,
    a: &BigUint,
    b: &BigUint,
    modulus: &BigUint,
    op: FieldOperation,
) -> BigUint {
    if op == FieldOperation::Div && b.is_zero() {
        assert_eq!(
            *a,
            BigUint::zero(),
            "division by zero is allowed only when dividing zero"
        );
    }

    // Exhaustive on purpose: adding a variant to `FieldOperation` must force a
    // decision here instead of silently taking the Add/Mul path.
    let result = match op {
        FieldOperation::Add | FieldOperation::Mul => {
            self.populate_carry_and_witness(a, b, op, modulus)
        }
        FieldOperation::Sub => {
            let result = (modulus + a - b) % modulus;
            // Witness `result + b`, then restore the `result` column, which
            // the call above filled with `(result + b) % modulus`.
            self.populate_carry_and_witness(&result, b, FieldOperation::Add, modulus);
            self.result = P::to_limbs_field::<F, _>(&result);
            result
        }
        FieldOperation::Div => {
            // b^(modulus - 2) is the inverse of b when the modulus is prime.
            let exponent = modulus - 2u32;
            let b_inverse = b.modpow(&exponent, modulus);
            let result = (a * b_inverse) % modulus;
            // Witness `result * b`, then restore the `result` column, which
            // the call above filled with `(result * b) % modulus`.
            self.populate_carry_and_witness(&result, b, FieldOperation::Mul, modulus);
            self.result = P::to_limbs_field::<F, _>(&result);
            result
        }
    };

    // Every populated limb is range checked as a u8.
    record.add_u8_range_checks_field(shard, channel, &self.result.0);
    record.add_u8_range_checks_field(shard, channel, &self.carry.0);
    record.add_u8_range_checks_field(shard, channel, &self.witness_low.0);
    record.add_u8_range_checks_field(shard, channel, &self.witness_high.0);

    result
}
/// Populates the columns for `a op b` using the modulus defined by `P`.
///
/// Convenience wrapper around `populate_with_modulus` that supplies
/// `P::modulus()`; returns whatever that method returns.
pub fn populate(
    &mut self,
    record: &mut impl ByteRecord,
    shard: u32,
    channel: u8,
    a: &BigUint,
    b: &BigUint,
    op: FieldOperation,
) -> BigUint {
    let field_modulus = P::modulus();
    self.populate_with_modulus(record, shard, channel, a, b, &field_modulus, op)
}
}
impl<V: Copy, P: FieldParameters> FieldOpCols<V, P> {
/// Adds the constraints tying these columns to `a op b` under the given
/// `modulus`.
///
/// The vanishing polynomial `lhs op' rhs - out - carry * modulus` is built and
/// handed to `eval_field_operation` together with the two witness halves.
/// For `Add`/`Mul` it reads `a op b - result - carry * modulus`; for
/// `Sub`/`Div` the roles of `a` and `result` are swapped, giving
/// `result + b - a - carry * modulus` and `result * b - a - carry * modulus`.
///
/// Afterwards `result`, `carry`, `witness_low` and `witness_high` are each
/// passed to `slice_range_check_u8` with the supplied `shard`, `channel` and
/// `is_real`.
///
/// NOTE(review): for `Div` the relation is `result * b = a`, which does not by
/// itself pin down `result` when `b` is zero — confirm callers handle that case.
#[allow(clippy::too_many_arguments)]
pub fn eval_with_modulus<AB: SP1AirBuilder<Var = V>>(
&self,
builder: &mut AB,
a: &(impl Into<Polynomial<AB::Expr>> + Clone),
b: &(impl Into<Polynomial<AB::Expr>> + Clone),
modulus: &(impl Into<Polynomial<AB::Expr>> + Clone),
op: FieldOperation,
shard: impl Into<AB::Expr> + Clone,
channel: impl Into<AB::Expr> + Clone,
is_real: impl Into<AB::Expr> + Clone,
) where
V: Into<AB::Expr>,
Limbs<V, P::Limbs>: Copy,
{
let p_a_param: Polynomial<AB::Expr> = (a).clone().into();
let p_b: Polynomial<AB::Expr> = (b).clone().into();
let p_modulus: Polynomial<AB::Expr> = (modulus).clone().into();
// Pick which polynomial acts as the left operand and which as the output:
// Sub/Div are checked through the inverse relation, so `self.result` takes
// the operand slot and the caller's `a` takes the output slot.
let (p_a, p_result): (Polynomial<_>, Polynomial<_>) = match op {
FieldOperation::Add | FieldOperation::Mul => (p_a_param, self.result.into()),
FieldOperation::Sub | FieldOperation::Div => (self.result.into(), p_a_param),
};
let p_carry: Polynomial<<AB as AirBuilder>::Expr> = self.carry.into();
// After the swap above only an addition or a multiplication remains.
let p_op = match op {
FieldOperation::Add | FieldOperation::Sub => p_a + p_b,
FieldOperation::Mul | FieldOperation::Div => p_a * p_b,
};
let p_op_minus_result: Polynomial<AB::Expr> = p_op - &p_result;
// lhs op' rhs - out - carry * modulus
let p_vanishing = p_op_minus_result - &(&p_carry * &p_modulus);
let p_witness_low = self.witness_low.0.iter().into();
let p_witness_high = self.witness_high.0.iter().into();
eval_field_operation::<AB, P>(builder, &p_vanishing, &p_witness_low, &p_witness_high);
// u8 range checks for every column group written by the populate methods.
builder.slice_range_check_u8(
&self.result.0,
shard.clone(),
channel.clone(),
is_real.clone(),
);
builder.slice_range_check_u8(
&self.carry.0,
shard.clone(),
channel.clone(),
is_real.clone(),
);
builder.slice_range_check_u8(
p_witness_low.coefficients(),
shard.clone(),
channel.clone(),
is_real.clone(),
);
// Last use of `is_real`, so it is moved rather than cloned.
builder.slice_range_check_u8(
p_witness_high.coefficients(),
shard.clone(),
channel.clone(),
is_real,
);
}
/// Adds the constraints for `a op b` using the modulus defined by `P`.
///
/// Builds the modulus polynomial from `P::modulus_field_iter` and forwards
/// everything to `eval_with_modulus`.
#[allow(clippy::too_many_arguments)]
pub fn eval<AB: SP1AirBuilder<Var = V>>(
    &self,
    builder: &mut AB,
    a: &(impl Into<Polynomial<AB::Expr>> + Clone),
    b: &(impl Into<Polynomial<AB::Expr>> + Clone),
    op: FieldOperation,
    shard: impl Into<AB::Expr> + Clone,
    channel: impl Into<AB::Expr> + Clone,
    is_real: impl Into<AB::Expr> + Clone,
) where
    V: Into<AB::Expr>,
    Limbs<V, P::Limbs>: Copy,
{
    // Lift each modulus limb into an expression, then collect them.
    let modulus_limbs = P::modulus_field_iter::<AB::F>().map(AB::Expr::from);
    let p_modulus = Polynomial::from_iter(modulus_limbs);
    self.eval_with_modulus::<AB>(builder, a, b, &p_modulus, op, shard, channel, is_real);
}
}
#[cfg(test)]
mod tests {
use num::BigUint;
use p3_air::BaseAir;
use p3_field::{Field, PrimeField32};
use super::{FieldOpCols, FieldOperation, Limbs};
use crate::air::MachineAir;
use crate::bytes::event::ByteRecord;
use crate::operations::field::params::FieldParameters;
use crate::runtime::Program;
use crate::stark::StarkGenericConfig;
use crate::utils::ec::edwards::ed25519::Ed25519BaseField;
use crate::utils::ec::weierstrass::secp256k1::Secp256k1BaseField;
use crate::utils::{
pad_to_power_of_two, uni_stark_prove as prove, uni_stark_verify as verify,
BabyBearPoseidon2,
};
use crate::{air::SP1AirBuilder, runtime::ExecutionRecord};
use core::borrow::{Borrow, BorrowMut};
use num::bigint::RandBigInt;
use p3_air::Air;
use p3_baby_bear::BabyBear;
use p3_field::AbstractField;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use rand::thread_rng;
use sp1_derive::AlignedBorrow;
use std::mem::size_of;
/// Row layout used by the test chip: both operands followed by the field
/// operation columns under test.
#[derive(AlignedBorrow, Debug, Clone)]
pub struct TestCols<T, P: FieldParameters> {
/// Limbs of the first operand.
pub a: Limbs<T, P::Limbs>,
/// Limbs of the second operand.
pub b: Limbs<T, P::Limbs>,
/// Columns populated and constrained for `a op b`.
pub a_op_b: FieldOpCols<T, P>,
}
/// Number of columns in one test row, measured as the byte size of
/// `TestCols<u8, _>`.
// NOTE(review): sized for `Secp256k1BaseField`, while the tests below
// instantiate the chip with `Ed25519BaseField`; presumably both have the same
// limb and witness counts — confirm before using another `P`.
pub const NUM_TEST_COLS: usize = size_of::<TestCols<u8, Secp256k1BaseField>>();
/// Minimal test chip that applies a single field operation to every row.
struct FieldOpChip<P: FieldParameters> {
/// The operation populated and constrained by this chip.
pub operation: FieldOperation,
/// Ties the chip to the field parameters `P` without storing a value.
pub _phantom: std::marker::PhantomData<P>,
}
impl<P: FieldParameters> FieldOpChip<P> {
    /// Creates a chip that exercises `operation`.
    pub const fn new(operation: FieldOperation) -> Self {
        let _phantom = std::marker::PhantomData;
        Self {
            operation,
            _phantom,
        }
    }
}
impl<F: PrimeField32, P: FieldParameters> MachineAir<F> for FieldOpChip<P> {
    type Record = ExecutionRecord;

    type Program = Program;

    /// Chip name, derived from the operation's `Debug` representation.
    fn name(&self) -> String {
        format!("FieldOp{:?}", self.operation)
    }

    /// Builds a trace of random operand pairs followed by a few fixed small
    /// pairs, populating the field operation columns for each row and pushing
    /// the resulting byte lookup events into `output`.
    fn generate_trace(
        &self,
        _: &ExecutionRecord,
        output: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        const NUM_ROWS: usize = 1 << 8;
        // Fixed small operand pairs appended after the random ones.
        const FIXED_PAIRS: [(u32, u32); 5] = [(0, 0), (0, 1), (1, 2), (4, 5), (10, 19)];

        let mut rng = thread_rng();
        let mut operands: Vec<(BigUint, BigUint)> = Vec::with_capacity(NUM_ROWS);
        for _ in 0..NUM_ROWS - FIXED_PAIRS.len() {
            let a = rng.gen_biguint(256) % &P::modulus();
            let b = rng.gen_biguint(256) % &P::modulus();
            operands.push((a, b));
        }
        for (a, b) in FIXED_PAIRS {
            operands.push((BigUint::from(a), BigUint::from(b)));
        }

        let mut values = Vec::with_capacity(operands.len() * NUM_TEST_COLS);
        for (a, b) in &operands {
            let mut row = [F::zero(); NUM_TEST_COLS];
            let cols: &mut TestCols<F, P> = row.as_mut_slice().borrow_mut();
            cols.a = P::to_limbs_field::<F, _>(a);
            cols.b = P::to_limbs_field::<F, _>(b);

            // Each row collects its own lookup events before handing them over.
            let mut blu_events = Vec::new();
            cols.a_op_b
                .populate(&mut blu_events, 1, 0, a, b, self.operation);
            output.add_byte_lookup_events(blu_events);

            values.extend_from_slice(&row);
        }

        let mut trace = RowMajorMatrix::new(values, NUM_TEST_COLS);
        pad_to_power_of_two::<NUM_TEST_COLS, F>(&mut trace.values);
        trace
    }

    /// The test chip is always part of the machine.
    fn included(&self, _: &Self::Record) -> bool {
        true
    }
}
impl<F: Field, P: FieldParameters> BaseAir<F> for FieldOpChip<P> {
/// Trace width of the test chip: one `TestCols` row, i.e. `NUM_TEST_COLS`.
fn width(&self) -> usize {
NUM_TEST_COLS
}
}
impl<AB, P: FieldParameters> Air<AB> for FieldOpChip<P>
where
    AB: SP1AirBuilder,
    Limbs<AB::Var, P::Limbs>: Copy,
{
    /// Constrains the current row's `a_op_b` columns against its `a` and `b`
    /// operands, with constant shard `1`, channel `0` and `is_real` `1`.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let row = main.row_slice(0);
        let cols: &TestCols<AB::Var, P> = (*row).borrow();

        let shard = AB::F::one();
        let channel = AB::F::zero();
        let is_real = AB::F::one();
        cols.a_op_b.eval(
            builder,
            &cols.a,
            &cols.b,
            self.operation,
            shard,
            channel,
            is_real,
        );
    }
}
/// Trace generation must succeed for every field operation, including `Div`
/// (which `prove_babybear` already exercises through the same code path), and
/// must yield a non-empty matrix made of whole `NUM_TEST_COLS`-wide rows.
#[test]
fn generate_trace() {
    for op in [
        FieldOperation::Add,
        FieldOperation::Mul,
        FieldOperation::Sub,
        FieldOperation::Div,
    ] {
        println!("op: {:?}", op);
        let chip: FieldOpChip<Ed25519BaseField> = FieldOpChip::new(op);
        let shard = ExecutionRecord::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        assert!(!trace.values.is_empty());
        assert_eq!(trace.values.len() % NUM_TEST_COLS, 0);
    }
}
/// Proves and verifies the test chip over BabyBear for each field operation.
#[test]
fn prove_babybear() {
    let config = BabyBearPoseidon2::new();
    let operations = [
        FieldOperation::Add,
        FieldOperation::Sub,
        FieldOperation::Mul,
        FieldOperation::Div,
    ];

    for op in operations {
        println!("op: {:?}", op);

        let mut prover_challenger = config.challenger();
        let chip = FieldOpChip::<Ed25519BaseField>::new(op);
        let input = ExecutionRecord::default();
        let mut output = ExecutionRecord::default();
        let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&input, &mut output);
        let proof = prove::<BabyBearPoseidon2, _>(&config, &chip, &mut prover_challenger, trace);

        // Verification starts from a fresh challenger, as proving did.
        let mut verifier_challenger = config.challenger();
        verify(&config, &chip, &mut verifier_challenger, &proof).unwrap();
    }
}
}