use crate::{
memory::{value_as_limbs, MemoryReadCols, MemoryWriteCols},
operations::field::field_op::FieldOpCols,
};
use crate::{
air::MemoryAirBuilder,
operations::{field::range::FieldLtCols, IsZeroOperation},
utils::{
limbs_from_access, limbs_from_prev_access, pad_rows_fixed, words_to_bytes_le,
words_to_bytes_le_vec,
},
};
use generic_array::GenericArray;
use num::{BigUint, One, Zero};
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use sp1_core_executor::{
events::{ByteRecord, FieldOperation, PrecompileEvent},
syscalls::SyscallCode,
ExecutionRecord, Program,
};
use sp1_curves::{
params::{Limbs, NumLimbs, NumWords},
uint256::U256Field,
};
use sp1_derive::AlignedBorrow;
use sp1_stark::{
air::{BaseAirBuilder, InteractionScope, MachineAir, Polynomial, SP1AirBuilder},
MachineRecord,
};
use std::{
borrow::{Borrow, BorrowMut},
mem::size_of,
};
use typenum::Unsigned;
/// Width of one trace row for this chip, computed as the byte size of
/// `Uint256MulCols<u8>`. It is used as the length of every row array and as the
/// value returned by `BaseAir::width`.
const NUM_COLS: usize = size_of::<Uint256MulCols<u8>>();
/// Stateless chip type for the uint256 multiplication precompile.
///
/// The struct has no fields; all behavior comes from the trait implementations
/// written for it.
#[derive(Default)]
pub struct Uint256MulChip;

impl Uint256MulChip {
    /// Creates the chip. Being a `const fn`, it can be used in constant contexts.
    pub const fn new() -> Self {
        Uint256MulChip
    }
}
/// Type-level length taken from `U256Field`'s `NumWords` implementation; it sizes
/// the per-operand memory column arrays in `Uint256MulCols`.
type WordsFieldElement = <U256Field as NumWords>::WordsFieldElement;
/// The same length as a `usize`, used as the loop bound when populating the
/// memory columns during trace generation.
const WORDS_FIELD_ELEMENT: usize = WordsFieldElement::USIZE;
/// Column layout of one trace row of `Uint256MulChip`.
///
/// The struct is `#[repr(C)]` and is borrowed directly from a flat slice of
/// field elements, so the order and types of the fields define the trace layout.
#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct Uint256MulCols<T> {
/// Shard of the event; passed to the memory-access and syscall interactions.
pub shard: T,
/// Clock of the event. The y/modulus reads are evaluated at `clk` and the x
/// access at `clk + 1`.
pub clk: T,
/// Row index within the trace; constrained to start at zero and increase by one.
pub nonce: T,
/// Pointer used as the base address of the x memory access.
pub x_ptr: T,
/// Pointer used as the base address of the combined y and modulus memory access.
pub y_ptr: T,
/// Write access for x. The previous value supplies the x operand and the
/// written value is constrained to equal `output.result`.
pub x_memory: GenericArray<MemoryWriteCols<T>, WordsFieldElement>,
/// Read access supplying the y operand.
pub y_memory: GenericArray<MemoryReadCols<T>, WordsFieldElement>,
/// Read access supplying the modulus; evaluated directly after `y_memory`
/// starting from `y_ptr`.
pub modulus_memory: GenericArray<MemoryReadCols<T>, WordsFieldElement>,
/// Is-zero gadget applied to the sum of the modulus limbs.
pub modulus_is_zero: IsZeroOperation<T>,
/// Constrained to equal `is_real * (1 - modulus_is_zero.result)`; gates the
/// output range check.
pub modulus_is_not_zero: T,
/// Columns of the `FieldOperation::Mul` field operation on x and y.
pub output: FieldOpCols<T, U256Field>,
/// Comparison columns relating `output.result` to the modulus limbs
/// (presumably asserting result < modulus; see `FieldLtCols`).
pub output_range_check: FieldLtCols<T, U256Field>,
/// Boolean flag: set to one on rows built from an event, left at zero on
/// padding rows.
pub is_real: T,
}
impl<F: PrimeField32> MachineAir<F> for Uint256MulChip {
type Record = ExecutionRecord;
type Program = Program;
fn name(&self) -> String {
"Uint256MulMod".to_string()
}
fn generate_trace(
&self,
input: &ExecutionRecord,
output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let rows_and_records = input
.get_precompile_events(SyscallCode::UINT256_MUL)
.chunks(1)
.map(|events| {
let mut records = ExecutionRecord::default();
let mut new_byte_lookup_events = Vec::new();
let rows = events
.iter()
.map(|(_, event)| {
let event = if let PrecompileEvent::Uint256Mul(event) = event {
event
} else {
unreachable!()
};
let mut row: [F; NUM_COLS] = [F::zero(); NUM_COLS];
let cols: &mut Uint256MulCols<F> = row.as_mut_slice().borrow_mut();
let x = BigUint::from_bytes_le(&words_to_bytes_le::<32>(&event.x));
let y = BigUint::from_bytes_le(&words_to_bytes_le::<32>(&event.y));
let modulus =
BigUint::from_bytes_le(&words_to_bytes_le::<32>(&event.modulus));
cols.is_real = F::one();
cols.shard = F::from_canonical_u32(event.shard);
cols.clk = F::from_canonical_u32(event.clk);
cols.x_ptr = F::from_canonical_u32(event.x_ptr);
cols.y_ptr = F::from_canonical_u32(event.y_ptr);
for i in 0..WORDS_FIELD_ELEMENT {
cols.x_memory[i]
.populate(event.x_memory_records[i], &mut new_byte_lookup_events);
cols.y_memory[i]
.populate(event.y_memory_records[i], &mut new_byte_lookup_events);
cols.modulus_memory[i].populate(
event.modulus_memory_records[i],
&mut new_byte_lookup_events,
);
}
let modulus_bytes = words_to_bytes_le_vec(&event.modulus);
let modulus_byte_sum = modulus_bytes.iter().map(|b| *b as u32).sum::<u32>();
IsZeroOperation::populate(&mut cols.modulus_is_zero, modulus_byte_sum);
let effective_modulus =
if modulus.is_zero() { BigUint::one() << 256 } else { modulus.clone() };
let result = cols.output.populate_with_modulus(
&mut new_byte_lookup_events,
event.shard,
&x,
&y,
&effective_modulus,
FieldOperation::Mul,
);
cols.modulus_is_not_zero = F::one() - cols.modulus_is_zero.result;
if cols.modulus_is_not_zero == F::one() {
cols.output_range_check.populate(
&mut new_byte_lookup_events,
event.shard,
&result,
&effective_modulus,
);
}
row
})
.collect::<Vec<_>>();
records.add_byte_lookup_events(new_byte_lookup_events);
(rows, records)
})
.collect::<Vec<_>>();
let mut rows = Vec::new();
for (row, mut record) in rows_and_records {
rows.extend(row);
output.append(&mut record);
}
pad_rows_fixed(
&mut rows,
|| {
let mut row: [F; NUM_COLS] = [F::zero(); NUM_COLS];
let cols: &mut Uint256MulCols<F> = row.as_mut_slice().borrow_mut();
let x = BigUint::zero();
let y = BigUint::zero();
cols.output.populate(&mut vec![], 0, &x, &y, FieldOperation::Mul);
row
},
input.fixed_log2_rows::<F, _>(self),
);
let mut trace =
RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_COLS);
for i in 0..trace.height() {
let cols: &mut Uint256MulCols<F> =
trace.values[i * NUM_COLS..(i + 1) * NUM_COLS].borrow_mut();
cols.nonce = F::from_canonical_usize(i);
}
trace
}
fn included(&self, shard: &Self::Record) -> bool {
if let Some(shape) = shard.shape.as_ref() {
shape.included::<F, _>(self)
} else {
!shard.get_precompile_events(SyscallCode::UINT256_MUL).is_empty()
}
}
}
/// Reports the trace width of the chip.
impl<F> BaseAir<F> for Uint256MulChip {
/// Returns `NUM_COLS`, the number of columns in one trace row.
fn width(&self) -> usize {
NUM_COLS
}
}
/// Constraint system for the uint256 multiplication precompile chip.
impl<AB> Air<AB> for Uint256MulChip
where
AB: SP1AirBuilder,
Limbs<AB::Var, <U256Field as NumLimbs>::Limbs>: Copy,
{
/// Emits all constraints and interactions for the current row (and the
/// nonce transition to the next row) into `builder`.
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local: &Uint256MulCols<AB::Var> = (*local).borrow();
let next = main.row_slice(1);
let next: &Uint256MulCols<AB::Var> = (*next).borrow();
// The nonce is zero on the first row and increases by one on every transition.
builder.when_first_row().assert_zero(local.nonce);
builder.when_transition().assert_eq(local.nonce + AB::Expr::one(), next.nonce);
// x comes from the previous value of the x write access (the access itself
// writes the result back, see the `assert_all_eq` below); y and the modulus
// come from their read accesses.
let x_limbs = limbs_from_prev_access(&local.x_memory);
let y_limbs = limbs_from_access(&local.y_memory);
let modulus_limbs = limbs_from_access(&local.modulus_memory);
// Sum of all modulus limbs, fed to the is-zero gadget on real rows.
let modulus_byte_sum =
modulus_limbs.0.iter().fold(AB::Expr::zero(), |acc, &limb| acc + limb);
IsZeroOperation::<AB::F>::eval(
builder,
modulus_byte_sum,
local.modulus_is_zero,
local.is_real.into(),
);
let modulus_is_zero = local.modulus_is_zero.result;
// Polynomial with 32 zero coefficients followed by a one. This matches the
// 2^256 substitute modulus used in trace generation, assuming byte-sized limbs.
let mut coeff_2_256 = Vec::new();
coeff_2_256.resize(32, AB::Expr::zero());
coeff_2_256.push(AB::Expr::one());
let modulus_polynomial: Polynomial<AB::Expr> = modulus_limbs.into();
// Select the modulus actually used: the supplied modulus when it is non-zero,
// otherwise the polynomial built above.
let p_modulus: Polynomial<AB::Expr> = modulus_polynomial
* (AB::Expr::one() - modulus_is_zero.into())
+ Polynomial::from_coefficients(&coeff_2_256) * modulus_is_zero.into();
// Constrain the multiplication of x and y under the selected modulus on real rows.
local.output.eval_with_modulus(
builder,
&x_limbs,
&y_limbs,
&p_modulus,
FieldOperation::Mul,
local.is_real,
);
// Range check of the result against the supplied modulus, enabled only when
// `modulus_is_not_zero` is set.
local.output_range_check.eval(
builder,
&local.output.result,
&modulus_limbs,
local.modulus_is_not_zero,
);
// Tie `modulus_is_not_zero` to `is_real * (1 - modulus_is_zero)`.
builder.assert_eq(
local.modulus_is_not_zero,
local.is_real * (AB::Expr::one() - modulus_is_zero.into()),
);
// On real rows the value written to x_memory must equal the computed result.
builder
.when(local.is_real)
.assert_all_eq(local.output.result, value_as_limbs(&local.x_memory));
// Memory access for x, evaluated at `clk + 1`.
builder.eval_memory_access_slice(
local.shard,
local.clk.into() + AB::Expr::one(),
local.x_ptr,
&local.x_memory,
local.is_real,
);
// Memory access for y followed by the modulus, evaluated as one slice
// starting at `y_ptr` at `clk`.
builder.eval_memory_access_slice(
local.shard,
local.clk.into(),
local.y_ptr,
&[local.y_memory, local.modulus_memory].concat(),
local.is_real,
);
// Receive the syscall with the `UINT256_MUL` id and the two pointers as
// arguments, gated by `is_real`, in the local interaction scope.
builder.receive_syscall(
local.shard,
local.clk,
local.nonce,
AB::F::from_canonical_u32(SyscallCode::UINT256_MUL.syscall_id()),
local.x_ptr,
local.y_ptr,
local.is_real,
InteractionScope::Local,
);
// `is_real` must be boolean.
builder.assert_bool(local.is_real);
}
}