use std::fmt::Debug;
use num::BigUint;
use p3_air::AirBuilder;
use p3_field::PrimeField32;
use sp1_derive::AlignedBorrow;
use super::field_op::FieldOpCols;
use super::params::{limbs_from_vec, Limbs};
use super::range::FieldLtCols;
use crate::air::SP1AirBuilder;
use crate::bytes::event::ByteRecord;
use crate::bytes::{ByteLookupEvent, ByteOpcode};
use crate::operations::field::params::FieldParameters;
use p3_field::AbstractField;
/// Trace columns that witness and constrain a square root of a field element `a` modulo
/// `P::modulus()`.
///
/// Populated by `FieldSqrtCols::populate` and constrained by `FieldSqrtCols::eval`.
#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct FieldSqrtCols<T, P: FieldParameters> {
    /// Columns for the multiplication `sqrt * sqrt`.
    ///
    /// Space-saving layout: after `populate`, the `result` slot holds the limbs of the square
    /// root, not of the product. `eval` reads the root from here and substitutes the input `a`
    /// as the expected product on a local copy.
    pub multiplication: FieldOpCols<T, P>,
    /// Columns comparing the square root against `P::modulus()`. They are populated with
    /// `(sqrt, modulus)` and evaluated with `(sqrt, modulus limbs)`.
    pub range: FieldLtCols<T, P>,
    /// Least significant bit of the square root. It is constrained to be boolean, to equal the
    /// caller's `is_odd` on real rows, and to equal `sqrt[0] AND 1` through a byte lookup.
    pub lsb: T,
}
impl<F: PrimeField32, P: FieldParameters> FieldSqrtCols<F, P> {
    /// Fills these columns with a square root of `a` and records the byte lookups they rely on.
    ///
    /// The root comes from `sqrt_fn`. It is not trusted blindly: it is squared through the
    /// multiplication columns and compared with `a`. After this call
    /// `self.multiplication.result` holds the limbs of the root (not of the product), which is
    /// the layout `eval` expects.
    ///
    /// Returns the root produced by `sqrt_fn`.
    ///
    /// # Panics
    ///
    /// Panics if `a` is not strictly below `P::modulus()`, or if squaring the root returned by
    /// `sqrt_fn` does not reproduce `a`.
    pub fn populate(
        &mut self,
        record: &mut impl ByteRecord,
        shard: u32,
        channel: u8,
        a: &BigUint,
        sqrt_fn: impl Fn(&BigUint) -> BigUint,
    ) -> BigUint {
        let p = P::modulus();
        assert!(a < &p);
        let root = sqrt_fn(a);

        // Witness `root * root` and make sure it really is `a`.
        let squared = self.multiplication.populate(
            record,
            shard,
            channel,
            &root,
            &root,
            super::field_op::FieldOperation::Mul,
        );
        assert_eq!(squared, a.clone());

        // Reuse the product slot to store the root. `eval` receives `a` separately and puts it
        // back as the expected product, which saves a dedicated set of root columns.
        self.multiplication.result = P::to_limbs_field::<F, _>(&root);

        // Range columns comparing `root` against the modulus.
        self.range.populate(record, shard, channel, &root, &p);

        // Parity bit of the root, witnessed by a byte AND lookup: lsb = low_byte & 1.
        let low_byte = P::to_limbs(&root)[0];
        self.lsb = F::from_canonical_u8(low_byte & 1);
        record.add_byte_lookup_event(ByteLookupEvent {
            shard,
            channel,
            opcode: ByteOpcode::AND,
            a1: self.lsb.as_canonical_u32() as u16,
            a2: 0,
            b: low_byte,
            c: 1,
        });

        // u8 range checks for every limb now stored in the product slot (the root's limbs).
        let root_limbs: Vec<u8> = self
            .multiplication
            .result
            .0
            .as_slice()
            .iter()
            .map(|limb| limb.as_canonical_u32() as u8)
            .collect();
        record.add_u8_range_checks(shard, channel, root_limbs.as_slice());

        root
    }
}
impl<V: Copy, P: FieldParameters> FieldSqrtCols<V, P>
where
    Limbs<V, P::Limbs>: Copy,
{
    /// Constrains these columns to hold a square root of `a` whose parity is `is_odd`.
    ///
    /// The root is read from `self.multiplication.result`, where `populate` stored it. The
    /// following constraints and lookups are emitted, gated by `is_real` where the underlying
    /// helpers take it:
    /// 1. the multiplication gadget with inputs `root, root` and expected result `a`;
    /// 2. the `FieldLtCols` comparison of `root` against the modulus limbs;
    /// 3. a u8 range check on every limb of `root`;
    /// 4. `lsb` is boolean, equals `is_odd` on real rows, and equals `root[0] AND 1` through a
    ///    byte lookup.
    pub fn eval<AB: SP1AirBuilder<Var = V>>(
        &self,
        builder: &mut AB,
        a: &Limbs<AB::Var, P::Limbs>,
        is_odd: impl Into<AB::Expr>,
        shard: impl Into<AB::Expr> + Clone,
        channel: impl Into<AB::Expr> + Clone,
        is_real: impl Into<AB::Expr> + Clone,
    ) where
        V: Into<AB::Expr>,
    {
        // The product slot actually stores the root. Work on a copy whose expected result is
        // `a`, so the gadget below checks `root * root == a`.
        let root = self.multiplication.result;
        let mut squaring = self.multiplication.clone();
        squaring.result = *a;
        squaring.eval(
            builder,
            &root,
            &root,
            super::field_op::FieldOperation::Mul,
            shard.clone(),
            channel.clone(),
            is_real.clone(),
        );

        // Compare the root against the modulus.
        let modulus = limbs_from_vec::<AB::Expr, P::Limbs, AB::F>(P::to_limbs_field_vec(
            &P::modulus(),
        ));
        self.range.eval(
            builder,
            &root,
            &modulus,
            shard.clone(),
            channel.clone(),
            is_real.clone(),
        );

        // Every limb of the root must be a byte.
        builder.slice_range_check_u8(
            root.0.as_slice(),
            shard.clone(),
            channel.clone(),
            is_real.clone(),
        );

        // Parity: `lsb` is a bit, matches the requested parity, and equals `root[0] & 1`.
        builder.assert_bool(self.lsb);
        builder.when(is_real.clone()).assert_eq(self.lsb, is_odd);
        builder.send_byte(
            ByteOpcode::AND.as_field::<AB::F>(),
            self.lsb,
            root[0],
            AB::F::one(),
            shard,
            channel,
            is_real,
        );
    }
}
#[cfg(test)]
mod tests {
    use num::{BigUint, One, Zero};
    use p3_air::BaseAir;
    use p3_field::{Field, PrimeField32};
    use super::{FieldSqrtCols, Limbs};
    use crate::air::MachineAir;
    use crate::bytes::event::ByteRecord;
    use crate::operations::field::params::FieldParameters;
    use crate::runtime::Program;
    use crate::stark::StarkGenericConfig;
    use crate::utils::ec::edwards::ed25519::{ed25519_sqrt, Ed25519BaseField};
    use crate::utils::{pad_to_power_of_two, BabyBearPoseidon2};
    use crate::utils::{uni_stark_prove as prove, uni_stark_verify as verify};
    use crate::{air::SP1AirBuilder, runtime::ExecutionRecord};
    use core::borrow::{Borrow, BorrowMut};
    use core::mem::size_of;
    use num::bigint::RandBigInt;
    use p3_air::Air;
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::Matrix;
    use rand::thread_rng;
    use sp1_derive::AlignedBorrow;

    /// One trace row: the input `a` followed by the square-root gadget columns.
    ///
    /// `#[repr(C)]` pins the field order, matching `FieldSqrtCols`, so the flat row slice
    /// borrowed through `AlignedBorrow` maps onto `a` first and `sqrt` second.
    #[derive(AlignedBorrow, Debug)]
    #[repr(C)]
    pub struct TestCols<T, P: FieldParameters> {
        pub a: Limbs<T, P::Limbs>,
        pub sqrt: FieldSqrtCols<T, P>,
    }

    /// Row width, measured on the byte-sized instantiation over the Ed25519 base field.
    pub const NUM_TEST_COLS: usize = size_of::<TestCols<u8, Ed25519BaseField>>();

    /// Minimal chip that exercises `FieldSqrtCols` using `ed25519_sqrt`.
    struct EdSqrtChip<P: FieldParameters> {
        pub _phantom: std::marker::PhantomData<P>,
    }

    impl<P: FieldParameters> EdSqrtChip<P> {
        pub const fn new() -> Self {
            Self {
                _phantom: std::marker::PhantomData,
            }
        }
    }

    impl<F: PrimeField32, P: FieldParameters> MachineAir<F> for EdSqrtChip<P> {
        type Record = ExecutionRecord;

        type Program = Program;

        fn name(&self) -> String {
            "EdSqrtChip".to_string()
        }

        /// Builds `1 << 8` rows: random squares reduced mod the Ed25519 base-field modulus,
        /// followed by the edge cases 0 and 1. The byte lookups each row emits are appended to
        /// `output`.
        fn generate_trace(
            &self,
            _: &ExecutionRecord,
            output: &mut ExecutionRecord,
        ) -> RowMajorMatrix<F> {
            let mut rng = thread_rng();
            let num_rows = 1 << 8;
            // Loop-invariant: compute the modulus once rather than once per sampled row.
            let modulus = Ed25519BaseField::modulus();
            // Squaring a random value guarantees every operand has a square root.
            let mut operands: Vec<BigUint> = (0..num_rows - 2)
                .map(|_| {
                    let a = rng.gen_biguint(256);
                    (&a * &a) % &modulus
                })
                .collect();
            operands.extend([BigUint::zero(), BigUint::one()]);
            let rows = operands
                .iter()
                .map(|a| {
                    let mut blu_events = Vec::new();
                    let mut row = [F::zero(); NUM_TEST_COLS];
                    let cols: &mut TestCols<F, P> = row.as_mut_slice().borrow_mut();
                    cols.a = P::to_limbs_field::<F, _>(a);
                    // shard = 1, channel = 0: must match the values passed to `eval` below.
                    cols.sqrt.populate(&mut blu_events, 1, 0, a, ed25519_sqrt);
                    output.add_byte_lookup_events(blu_events);
                    row
                })
                .collect::<Vec<_>>();
            let mut trace = RowMajorMatrix::new(
                rows.into_iter().flatten().collect::<Vec<_>>(),
                NUM_TEST_COLS,
            );
            pad_to_power_of_two::<NUM_TEST_COLS, F>(&mut trace.values);
            trace
        }

        fn included(&self, _: &Self::Record) -> bool {
            true
        }
    }

    impl<F: Field, P: FieldParameters> BaseAir<F> for EdSqrtChip<P> {
        fn width(&self) -> usize {
            NUM_TEST_COLS
        }
    }

    impl<AB, P: FieldParameters> Air<AB> for EdSqrtChip<P>
    where
        AB: SP1AirBuilder,
        Limbs<AB::Var, P::Limbs>: Copy,
    {
        fn eval(&self, builder: &mut AB) {
            let main = builder.main();
            let local = main.row_slice(0);
            let local: &TestCols<AB::Var, P> = (*local).borrow();
            // Arguments: is_odd = 0, shard = 1, channel = 0, is_real = 1.
            // is_odd = 0 forces every root's lsb to 0, so this test expects `ed25519_sqrt` to
            // return the even root. NOTE(review): confirm against ed25519_sqrt.
            local.sqrt.eval(
                builder,
                &local.a,
                AB::F::zero(),
                AB::F::one(),
                AB::F::zero(),
                AB::F::one(),
            );
        }
    }

    #[test]
    fn generate_trace() {
        let chip: EdSqrtChip<Ed25519BaseField> = EdSqrtChip::new();
        let shard = ExecutionRecord::default();
        let _: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
    }

    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::new();
        let mut challenger = config.challenger();
        let chip: EdSqrtChip<Ed25519BaseField> = EdSqrtChip::new();
        let shard = ExecutionRecord::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        let proof = prove::<BabyBearPoseidon2, _>(&config, &chip, &mut challenger, trace);
        let mut challenger = config.challenger();
        verify(&config, &chip, &mut challenger, &proof).unwrap();
    }
}