use crate::{
air::SP1Operation,
memory::MemoryAccessColsU8,
operations::{field::field_op::FieldOpCols, AddrAddOperation, IsZeroOperationInput},
};
use crate::{
air::SP1CoreAirBuilder,
operations::{field::range::FieldLtCols, IsZeroOperation, SyscallAddrOperation},
utils::{limbs_to_words, next_multiple_of_32, words_to_bytes_le, words_to_bytes_le_vec},
};
use generic_array::GenericArray;
use itertools::Itertools;
use num::{BigUint, One, Zero};
use slop_air::{Air, BaseAir};
use slop_algebra::{AbstractField, PrimeField32};
use slop_matrix::Matrix;
use sp1_core_executor::{
events::{ByteRecord, FieldOperation, MemoryRecordEnum, PrecompileEvent},
ExecutionRecord, Program, SyscallCode,
};
use sp1_curves::{
params::{Limbs, NumLimbs, NumWords},
uint256::U256Field,
};
use sp1_derive::AlignedBorrow;
use sp1_hypercube::{
air::{InteractionScope, MachineAir},
Word,
};
use sp1_primitives::polynomial::Polynomial;
use std::{
borrow::{Borrow, BorrowMut},
mem::{size_of, MaybeUninit},
};
use typenum::Unsigned;
/// Number of columns in one trace row of the chip (returned by `BaseAir::width`).
///
/// Computed as the byte size of `Uint256MulCols<u8>`; this equals the column
/// count assuming every nested column struct stores only `T` values (one byte
/// each when `T = u8`).
const NUM_COLS: usize = size_of::<Uint256MulCols<u8>>();
/// Stateless marker type for the `UINT256_MUL` precompile chip.
///
/// All behavior lives in the trait implementations; the value itself carries
/// no data.
#[derive(Default)]
pub struct Uint256MulChip;

impl Uint256MulChip {
    /// Builds the chip. Usable in `const` contexts.
    pub const fn new() -> Self {
        Uint256MulChip
    }
}
/// Type-level number of memory words that make up one `U256Field` element.
type WordsFieldElement = <U256Field as NumWords>::WordsFieldElement;
/// `WordsFieldElement` as a plain `usize`, used for array lengths and loops.
/// Consecutive words are addressed 8 bytes apart by the trace generator.
const WORDS_FIELD_ELEMENT: usize = WordsFieldElement::USIZE;
/// Column layout of a single `Uint256MulChip` trace row.
///
/// The struct is `#[repr(C)]` and borrowed directly from a flat row slice, so
/// field order and types define the trace layout and must not be reordered.
#[derive(Debug, Clone, AlignedBorrow)]
#[repr(C)]
pub struct Uint256MulCols<T> {
/// Upper part of the syscall clock: `clk >> 24`.
pub clk_high: T,
/// Lower 24 bits of the syscall clock: `clk & 0xFFFFFF`.
pub clk_low: T,
/// Address-check columns for the `x` pointer (evaluated with the literal 32).
pub x_ptr: SyscallAddrOperation<T>,
/// Address-check columns for the `y` pointer (evaluated with the literal 64;
/// the modulus is read from the words that follow `y`).
pub y_ptr: SyscallAddrOperation<T>,
/// Addresses `x_ptr + 8 * i` of the words of `x`.
pub x_addrs: [AddrAddOperation<T>; WORDS_FIELD_ELEMENT],
/// Addresses `y_ptr + 8 * i`; the first half covers `y`, the second half
/// covers the modulus.
pub y_and_modulus_addrs: [AddrAddOperation<T>; 2 * WORDS_FIELD_ELEMENT],
/// Memory accesses for the words of `x`; these are write accesses that
/// receive the result.
pub x_memory: GenericArray<MemoryAccessColsU8<T>, WordsFieldElement>,
/// Memory accesses (reads) for the words of `y`.
pub y_memory: GenericArray<MemoryAccessColsU8<T>, WordsFieldElement>,
/// Memory accesses (reads) for the words of the modulus.
pub modulus_memory: GenericArray<MemoryAccessColsU8<T>, WordsFieldElement>,
/// Zero test over the sum of the modulus limbs.
pub modulus_is_zero: IsZeroOperation<T>,
/// Constrained to `is_real * (1 - modulus_is_zero.result)`; selector for the
/// output range check.
pub modulus_is_not_zero: T,
/// Columns of the modular multiplication `x * y` under the effective modulus
/// (the modulus itself, or 2^256 when the modulus is zero).
pub output: FieldOpCols<T, U256Field>,
/// Comparison of the result against the modulus; only populated and enforced
/// when the modulus is non-zero.
pub output_range_check: FieldLtCols<T, U256Field>,
/// One on rows that carry an event; padding rows are zero-filled.
pub is_real: T,
}
impl<F: PrimeField32> MachineAir<F> for Uint256MulChip {
type Record = ExecutionRecord;
type Program = Program;
fn name(&self) -> &'static str {
"Uint256MulMod"
}
fn num_rows(&self, input: &Self::Record) -> Option<usize> {
let nb_rows = input.get_precompile_events(SyscallCode::UINT256_MUL).len();
let size_log2 = input.fixed_log2_rows::<F, _>(self);
let padded_nb_rows = next_multiple_of_32(nb_rows, size_log2);
Some(padded_nb_rows)
}
fn generate_trace_into(
&self,
input: &ExecutionRecord,
output: &mut ExecutionRecord,
buffer: &mut [MaybeUninit<F>],
) {
let padded_nb_rows = <Uint256MulChip as MachineAir<F>>::num_rows(self, input).unwrap();
let events = input.get_precompile_events(SyscallCode::UINT256_MUL);
let chunk_size = 1;
let num_event_rows = events.len();
unsafe {
let padding_start = num_event_rows * NUM_COLS;
let padding_size = (padded_nb_rows - num_event_rows) * NUM_COLS;
if padding_size > 0 {
core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
}
}
let buffer_ptr = buffer.as_mut_ptr() as *mut F;
let buffer_as_slice =
unsafe { core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_COLS) };
let mut new_byte_lookup_events = Vec::new();
buffer_as_slice.chunks_exact_mut(chunk_size * NUM_COLS).enumerate().for_each(
|(i, rows)| {
rows.chunks_mut(NUM_COLS).enumerate().for_each(|(j, row)| {
let idx = i * chunk_size + j;
if idx < events.len() {
let event = &events[idx].1;
let event = if let PrecompileEvent::Uint256Mul(event) = event {
event
} else {
unreachable!()
};
unsafe {
core::ptr::write_bytes(row.as_mut_ptr(), 0, NUM_COLS);
}
let cols: &mut Uint256MulCols<F> = row.borrow_mut();
let x = BigUint::from_bytes_le(&words_to_bytes_le::<32>(&event.x));
let y = BigUint::from_bytes_le(&words_to_bytes_le::<32>(&event.y));
let modulus =
BigUint::from_bytes_le(&words_to_bytes_le::<32>(&event.modulus));
cols.is_real = F::one();
cols.clk_high = F::from_canonical_u32((event.clk >> 24) as u32);
cols.clk_low = F::from_canonical_u32((event.clk & 0xFFFFFF) as u32);
cols.x_ptr.populate(&mut new_byte_lookup_events, event.x_ptr, 32);
cols.y_ptr.populate(&mut new_byte_lookup_events, event.y_ptr, 64);
let modulus_ptr = event.y_ptr + WORDS_FIELD_ELEMENT as u64 * 8;
for i in 0..WORDS_FIELD_ELEMENT {
let x_memory_record =
MemoryRecordEnum::Write(event.x_memory_records[i]);
let y_memory_record = MemoryRecordEnum::Read(event.y_memory_records[i]);
let modulus_memory_record =
MemoryRecordEnum::Read(event.modulus_memory_records[i]);
cols.x_memory[i].populate(x_memory_record, &mut new_byte_lookup_events);
cols.y_memory[i].populate(y_memory_record, &mut new_byte_lookup_events);
cols.modulus_memory[i]
.populate(modulus_memory_record, &mut new_byte_lookup_events);
cols.x_addrs[i].populate(
&mut new_byte_lookup_events,
event.x_ptr,
8 * i as u64,
);
cols.y_and_modulus_addrs[i].populate(
&mut new_byte_lookup_events,
event.y_ptr,
8 * i as u64,
);
cols.y_and_modulus_addrs[i + WORDS_FIELD_ELEMENT].populate(
&mut new_byte_lookup_events,
modulus_ptr,
8 * i as u64,
);
}
let modulus_bytes = words_to_bytes_le_vec(&event.modulus);
let modulus_byte_sum = modulus_bytes.iter().map(|b| *b as u64).sum::<u64>();
IsZeroOperation::populate(&mut cols.modulus_is_zero, modulus_byte_sum);
let effective_modulus =
if modulus.is_zero() { BigUint::one() << 256 } else { modulus.clone() };
let result = cols.output.populate_with_modulus(
&mut new_byte_lookup_events,
&x,
&y,
&effective_modulus,
FieldOperation::Mul,
);
cols.modulus_is_not_zero = F::one() - cols.modulus_is_zero.result;
if cols.modulus_is_not_zero == F::one() {
cols.output_range_check.populate(
&mut new_byte_lookup_events,
&result,
&effective_modulus,
);
}
}
})
},
);
for row in num_event_rows..padded_nb_rows {
let row_start = row * NUM_COLS;
let row = unsafe {
core::slice::from_raw_parts_mut(
buffer[row_start..].as_mut_ptr() as *mut F,
NUM_COLS,
)
};
let cols: &mut Uint256MulCols<F> = row.borrow_mut();
let x = BigUint::zero();
let y = BigUint::zero();
cols.output.populate(&mut vec![], &x, &y, FieldOperation::Mul);
}
output.add_byte_lookup_events(new_byte_lookup_events);
}
fn included(&self, shard: &Self::Record) -> bool {
if let Some(shape) = shard.shape.as_ref() {
shape.included::<F, _>(self)
} else {
!shard.get_precompile_events(SyscallCode::UINT256_MUL).is_empty()
}
}
}
impl<F> BaseAir<F> for Uint256MulChip {
/// Width of the trace: the number of columns in `Uint256MulCols`.
fn width(&self) -> usize {
NUM_COLS
}
}
impl<AB> Air<AB> for Uint256MulChip
where
AB: SP1CoreAirBuilder,
Limbs<AB::Var, <U256Field as NumLimbs>::Limbs>: Copy,
{
/// Emits the constraints and interactions for one row of the chip.
///
/// The row ties together: the limbs of `x`, `y` and the modulus taken from
/// the memory access columns, the zero test on the modulus, the modular
/// multiplication under the effective modulus, the range check on the
/// result, the address computations, the memory reads and the write, and
/// the syscall receive. Everything is gated by `is_real` (or by
/// `modulus_is_not_zero` for the range check).
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local: &Uint256MulCols<AB::Var> = (*local).borrow();
// Limb vectors of the three operands, derived from their memory columns.
let x_limb_vec = builder.generate_limbs(&local.x_memory, local.is_real.into());
let x_limbs: Limbs<AB::Expr, <U256Field as NumLimbs>::Limbs> =
Limbs(x_limb_vec.try_into().expect("failed to convert limbs"));
let y_limb_vec = builder.generate_limbs(&local.y_memory, local.is_real.into());
let y_limbs: Limbs<AB::Expr, <U256Field as NumLimbs>::Limbs> =
Limbs(y_limb_vec.try_into().expect("failed to convert limbs"));
let modulus_limb_vec = builder.generate_limbs(&local.modulus_memory, local.is_real.into());
let modulus_limbs: Limbs<AB::Expr, <U256Field as NumLimbs>::Limbs> =
Limbs(modulus_limb_vec.try_into().expect("failed to convert limbs"));
// Sum of all modulus limbs, fed to the zero test below.
// NOTE(review): this detects a zero modulus only if the limbs are small
// enough (presumably bytes) that the sum cannot wrap in the field — confirm.
let modulus_byte_sum =
modulus_limbs.clone().0.iter().fold(AB::Expr::zero(), |acc, limb| acc + limb.clone());
IsZeroOperation::<AB::F>::eval(
builder,
IsZeroOperationInput::new(
modulus_byte_sum,
local.modulus_is_zero,
local.is_real.into(),
),
);
let modulus_is_zero = local.modulus_is_zero.result;
// Coefficients of the monomial of degree 32: thirty-two zeros followed by a
// one. This is the polynomial counterpart of the `1 << 256` modulus that the
// trace generator substitutes for a zero modulus.
let mut coeff_2_256 = Vec::new();
coeff_2_256.resize(32, AB::Expr::zero());
coeff_2_256.push(AB::Expr::one());
let modulus_polynomial: Polynomial<AB::Expr> = modulus_limbs.clone().into();
// Effective modulus: the given modulus when it is non-zero, 2^256 otherwise.
let p_modulus: Polynomial<AB::Expr> = modulus_polynomial
* (AB::Expr::one() - modulus_is_zero.into())
+ Polynomial::from_coefficients(&coeff_2_256) * modulus_is_zero.into();
// Multiplication of x and y under the effective modulus.
local.output.eval_with_modulus(
builder,
&x_limbs,
&y_limbs,
&p_modulus,
FieldOperation::Mul,
local.is_real,
);
// Compare the result against the modulus, only when the modulus is non-zero
// (presumably enforcing result < modulus — see `FieldLtCols`).
local.output_range_check.eval(
builder,
&local.output.result,
&modulus_limbs.clone(),
local.modulus_is_not_zero,
);
// Pin the selector: modulus_is_not_zero = is_real * (1 - modulus_is_zero).
builder.assert_eq(
local.modulus_is_not_zero,
local.is_real * (AB::Expr::one() - modulus_is_zero.into()),
);
// Result limbs regrouped into words for the memory write below.
let result_words = limbs_to_words::<AB>(local.output.result.0.to_vec());
// Pointer checks; the literals 32 and 64 match the values used when the
// trace is populated (presumably the byte length of each region — confirm).
let x_ptr =
SyscallAddrOperation::<AB::F>::eval(builder, 32, local.x_ptr, local.is_real.into());
let y_ptr =
SyscallAddrOperation::<AB::F>::eval(builder, 64, local.y_ptr, local.is_real.into());
// Word addresses of x: x_ptr + 8 * i. The pointer supplies three limbs and
// the fourth limb of the word is zero.
for i in 0..local.x_addrs.len() {
AddrAddOperation::<AB::F>::eval(
builder,
Word([x_ptr[0].into(), x_ptr[1].into(), x_ptr[2].into(), AB::Expr::zero()]),
Word::from(8 * i as u64),
local.x_addrs[i],
local.is_real.into(),
);
}
// Word addresses of y followed by the modulus: y_ptr + 8 * i.
for i in 0..local.y_and_modulus_addrs.len() {
AddrAddOperation::<AB::F>::eval(
builder,
Word([y_ptr[0].into(), y_ptr[1].into(), y_ptr[2].into(), AB::Expr::zero()]),
Word::from(8 * i as u64),
local.y_and_modulus_addrs[i],
local.is_real.into(),
);
}
// The result is written over x at clk_low + 1, i.e. one tick after the
// reads of y and the modulus below.
builder.eval_memory_access_slice_write(
local.clk_high,
local.clk_low + AB::Expr::one(),
&local.x_addrs.map(|addr| addr.value.map(Into::into)),
&local.x_memory.iter().map(|access| access.memory_access).collect_vec(),
result_words,
local.is_real,
);
// y and the modulus are read as one concatenated slice at clk_low, matching
// the layout of `y_and_modulus_addrs`.
builder.eval_memory_access_slice_read(
local.clk_high,
local.clk_low.into(),
&local.y_and_modulus_addrs.map(|addr| addr.value.map(Into::into)),
&[local.y_memory, local.modulus_memory]
.concat()
.iter()
.map(|access| access.memory_access)
.collect_vec(),
local.is_real,
);
// Receive the UINT256_MUL syscall with both pointers as arguments.
builder.receive_syscall(
local.clk_high,
local.clk_low.into(),
AB::F::from_canonical_u32(SyscallCode::UINT256_MUL.syscall_id()),
x_ptr.map(Into::into),
y_ptr.map(Into::into),
local.is_real,
InteractionScope::Local,
);
}
}