use std::num::Wrapping;
use crate::{
air::{SP1Operation, SP1OperationBuilder, WordAirBuilder},
operations::{U16MSBOperation, U16toU8OperationSafe, U16toU8OperationSafeInput},
};
use sp1_core_executor::{
events::{ByteLookupEvent, ByteRecord},
ByteOpcode,
};
use sp1_hypercube::{air::SP1AirBuilder, Word};
use slop_air::AirBuilder;
use slop_algebra::{AbstractField, Field};
use sp1_derive::{AlignedBorrow, InputExpr, InputParams, IntoShape, SP1OperationBuilder};
use sp1_primitives::consts::{
u64_to_u16_limbs, BYTE_SIZE, LONG_WORD_BYTE_SIZE, WORD_BYTE_SIZE, WORD_SIZE,
};
use super::{U16MSBOperationInput, U16toU8Operation};
/// Byte with every bit set; used to fill the upper bytes of a sign-extended operand.
const BYTE_MASK: u8 = u8::MAX;
/// Returns the most significant bit (0 or 1) of a little-endian 8-byte word.
///
/// The bit is the top bit of `a[7]`, the highest-addressed byte. Shifting a `u8` right by
/// `u8::BITS - 1` already leaves only that bit, so no mask or cast is needed.
pub const fn get_msb(a: [u8; 8]) -> u8 {
    a[7] >> (u8::BITS - 1)
}
/// Trace columns for multiplying two 64-bit words.
///
/// [`MulOperation::populate`] fills these columns and [`MulOperation::eval`] constrains them.
/// Both operands are widened to `LONG_WORD_BYTE_SIZE` bytes and multiplied byte by byte. The
/// widening optionally sign-extends each operand. `a_word` is then read from the low or high
/// half of the product.
///
/// Field order is part of the `#[repr(C)]` layout; do not reorder fields.
#[derive(AlignedBorrow, Default, Debug, Clone, Copy, IntoShape, SP1OperationBuilder)]
#[repr(C)]
pub struct MulOperation<T> {
    /// Carry out of each byte position of the product accumulation; range checked as u16.
    pub carry: [T; LONG_WORD_BYTE_SIZE],
    /// Bytes of the product truncated to `LONG_WORD_BYTE_SIZE` bytes; range checked as u8.
    pub product: [T; LONG_WORD_BYTE_SIZE],
    /// Columns for the `U16toU8OperationSafe` decomposition of `b`'s u16 limbs into bytes.
    pub b_lower_byte: U16toU8Operation<T>,
    /// Columns for the `U16toU8OperationSafe` decomposition of `c`'s u16 limbs into bytes.
    pub c_lower_byte: U16toU8Operation<T>,
    /// Top bit of `b`'s highest byte, checked with an `MSB` byte lookup.
    pub b_msb: T,
    /// Top bit of `c`'s highest byte, checked with an `MSB` byte lookup.
    pub c_msb: T,
    /// MSB of limb 1 of the result (product bytes 2 and 3). Only active for MULW, where it
    /// fills the upper two u16 limbs of `a_word`.
    pub product_msb: U16MSBOperation<T>,
    /// Whether `b`'s upper bytes are filled with `0xff`: `(is_mulh + is_mulhsu) * b_msb`.
    pub b_sign_extend: T,
    /// Whether `c`'s upper bytes are filled with `0xff`: `is_mulh * c_msb`.
    pub c_sign_extend: T,
}
impl<F: Field> MulOperation<F> {
    /// Fills the columns for multiplying `b_u64` by `c_u64`.
    ///
    /// It also records the byte lookups and range checks that [`MulOperation::eval`] sends.
    ///
    /// Both operands are multiplied as `LONG_WORD_BYTE_SIZE`-byte values. An operand's upper
    /// bytes are `0xff` when it is negative and treated as signed: `b` for MULH/MULHSU, `c` for
    /// MULH only. Otherwise the upper bytes are zero.
    ///
    /// # Arguments
    /// * `record` - receives the byte-lookup and range-check events emitted here.
    /// * `b_u64`, `c_u64` - the raw operand values.
    /// * `is_mulh`, `is_mulhsu` - select signed handling of the operands as described above.
    /// * `is_mulw` - when set, `product_msb` is populated from the sign-extended 32-bit product.
    ///   Otherwise `product_msb.msb` is set to zero.
    pub fn populate(
        &mut self,
        record: &mut impl ByteRecord,
        b_u64: u64,
        c_u64: u64,
        is_mulh: bool,
        is_mulhsu: bool,
        is_mulw: bool,
    ) {
        let b_word = b_u64.to_le_bytes();
        let c_word = c_u64.to_le_bytes();

        // Only MULW uses `product_msb`, so the 32-bit product is computed only in that case.
        if is_mulw {
            // The truncating `as i32` casts intentionally keep only the low 32 bits.
            let mulw_value = (Wrapping(b_u64 as i32) * Wrapping(c_u64 as i32)).0 as i64 as u64;
            let limbs = u64_to_u16_limbs(mulw_value);
            // `eval` checks the MSB of `a_word` limb 1.
            self.product_msb.populate_msb(record, limbs[1]);
        } else {
            self.product_msb.msb = F::zero();
        }

        self.b_lower_byte.populate_u16_to_u8_safe(record, b_u64);
        self.c_lower_byte.populate_u16_to_u8_safe(record, c_u64);

        let b_msb = get_msb(b_word);
        let c_msb = get_msb(c_word);
        self.b_msb = F::from_canonical_u8(b_msb);
        self.c_msb = F::from_canonical_u8(c_msb);

        let b_sign_extend = (is_mulh || is_mulhsu) && b_msb == 1;
        let c_sign_extend = is_mulh && c_msb == 1;
        self.b_sign_extend = if b_sign_extend { F::one() } else { F::zero() };
        self.c_sign_extend = if c_sign_extend { F::one() } else { F::zero() };

        // One MSB lookup per operand, matching the two `send_byte` calls in `eval`.
        record.add_byte_lookup_events(vec![
            ByteLookupEvent {
                opcode: ByteOpcode::MSB,
                a: u16::from(b_msb),
                b: b_word[WORD_BYTE_SIZE - 1],
                c: 0,
            },
            ByteLookupEvent {
                opcode: ByteOpcode::MSB,
                a: u16::from(c_msb),
                b: c_word[WORD_BYTE_SIZE - 1],
                c: 0,
            },
        ]);

        // Widen an operand to LONG_WORD_BYTE_SIZE bytes on the stack. The upper bytes are 0xff
        // when sign-extending and zero otherwise. Zero bytes add nothing to the product.
        let widen = |word: [u8; 8], sign_extend: bool| {
            let mut bytes = [0u8; LONG_WORD_BYTE_SIZE];
            bytes[..word.len()].copy_from_slice(&word);
            if sign_extend {
                bytes[word.len()..].fill(BYTE_MASK);
            }
            bytes
        };
        let b = widen(b_word, b_sign_extend);
        let c = widen(c_word, c_sign_extend);

        // Schoolbook multiplication truncated to LONG_WORD_BYTE_SIZE columns. Each column sums
        // at most LONG_WORD_BYTE_SIZE byte products, each below 2^16, so u32 accumulators
        // suffice for the sizes used here.
        let mut product = [0u32; LONG_WORD_BYTE_SIZE];
        for (i, &b_byte) in b.iter().enumerate() {
            for (j, &c_byte) in c[..LONG_WORD_BYTE_SIZE - i].iter().enumerate() {
                product[i + j] += u32::from(b_byte) * u32::from(c_byte);
            }
        }

        // Normalize each column to a byte, propagating the carry into the next column. The
        // carry out of the last column is recorded but not propagated, matching `eval`.
        let base = 1u32 << BYTE_SIZE;
        let mut carry = [0u32; LONG_WORD_BYTE_SIZE];
        let mut carry_in = 0u32;
        for (column, carry_out) in product.iter_mut().zip(carry.iter_mut()) {
            let total = *column + carry_in;
            *carry_out = total / base;
            *column = total % base;
            carry_in = *carry_out;
        }

        self.carry = carry.map(F::from_canonical_u32);
        self.product = product.map(F::from_canonical_u32);

        // The narrowing casts are lossless for the values computed above: every product byte
        // is `< base`, and `eval` range-checks the carries as u16.
        record.add_u16_range_checks(&carry.map(|x| x as u16));
        record.add_u8_range_checks(&product.map(|x| x as u8));
    }

    /// Constrains `cols` so that `a_word` is the selected multiplication of `b_word` by `c_word`.
    ///
    /// Constraints emitted:
    /// * `b_word` and `c_word` are decomposed into bytes with `U16toU8OperationSafe`.
    /// * `b_msb` and `c_msb` are tied to each operand's top byte by an `MSB` byte lookup.
    /// * `b_sign_extend = (is_mulh + is_mulhsu) * b_msb` and `c_sign_extend = is_mulh * c_msb`.
    ///   A set flag makes the widened operand's upper bytes `0xff`.
    /// * When `is_real`, `product` and `carry` satisfy the byte-wise carry relation for the
    ///   truncated product of the widened operands. `product` is range-checked as u8 and
    ///   `carry` as u16.
    /// * `a_word` is selected by opcode:
    ///   - `is_mul`: the low half of the product.
    ///   - `is_mulh`, `is_mulhu` or `is_mulhsu`: the high half.
    ///   - `is_mulw`: the low two u16 limbs, then `product_msb.msb * 0xffff` in the upper two.
    /// * The column and selector flags and `is_real` are asserted boolean. So is the sum of the
    ///   five opcode selectors, so at most one selector is active.
    ///
    /// NOTE(review): the order of builder calls presumably fixes the chip's constraint and
    /// interaction layout. Keep statement order stable when editing.
    // The flat argument list mirrors `MulOperationInput`, which `lower` unpacks into this call.
    #[allow(clippy::too_many_arguments)]
    pub fn eval<
        AB: SP1AirBuilder
            + SP1OperationBuilder<U16toU8OperationSafe>
            + SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>,
    >(
        builder: &mut AB,
        a_word: Word<AB::Expr>,
        b_word: Word<AB::Expr>,
        c_word: Word<AB::Expr>,
        cols: MulOperation<AB::Var>,
        is_real: AB::Expr,
        is_mul: AB::Expr,
        is_mulh: AB::Expr,
        is_mulw: AB::Expr,
        is_mulhu: AB::Expr,
        is_mulhsu: AB::Expr,
    ) {
        let zero: AB::Expr = AB::F::zero().into();
        let base = AB::F::from_canonical_u32(1 << 8);
        let one: AB::Expr = AB::F::one().into();
        let byte_mask = AB::F::from_canonical_u8(BYTE_MASK);

        // Decompose the u16 limbs of both operands into bytes.
        let b_input = U16toU8OperationSafeInput::new(b_word.0, cols.b_lower_byte, is_real.clone());
        let b = U16toU8OperationSafe::eval(builder, b_input);
        let c_input = U16toU8OperationSafeInput::new(c_word.0, cols.c_lower_byte, is_real.clone());
        let c = U16toU8OperationSafe::eval(builder, c_input);

        // Tie `b_msb` / `c_msb` to the top byte of each operand via the MSB byte lookup.
        let msb_opcode = AB::F::from_canonical_u32(ByteOpcode::MSB as u32);
        let (b_msb, c_msb) = {
            let msb_pairs = [
                (cols.b_msb, b[WORD_BYTE_SIZE - 1].clone()),
                (cols.c_msb, c[WORD_BYTE_SIZE - 1].clone()),
            ];
            for msb_pair in msb_pairs.iter() {
                let msb = msb_pair.0;
                let byte = msb_pair.1.clone();
                builder.send_byte(msb_opcode, msb, byte, zero.clone(), is_real.clone());
            }
            (cols.b_msb, cols.c_msb)
        };

        // `product_msb` is the MSB of result limb 1. It is active only when `is_mulw`.
        <U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
            builder,
            U16MSBOperationInput::new(a_word.0[1].clone(), cols.product_msb, is_mulw.clone()),
        );

        // Sign-extend `b` for MULH/MULHSU and `c` for MULH, only when the operand's MSB is set.
        let (b_sign_extend, c_sign_extend) = {
            let is_b_i64 = is_mulh.clone() + is_mulhsu.clone();
            let is_c_i64 = is_mulh.clone();
            builder.assert_eq(cols.b_sign_extend, is_b_i64 * b_msb);
            builder.assert_eq(cols.c_sign_extend, is_c_i64 * c_msb);
            (cols.b_sign_extend, cols.c_sign_extend)
        };

        // Widen both operands to LONG_WORD_BYTE_SIZE bytes. Each upper byte is
        // `sign_extend * 0xff`.
        let (b, c) = {
            let mut b_extended: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
            let mut c_extended: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
            for i in 0..LONG_WORD_BYTE_SIZE {
                if i < WORD_BYTE_SIZE {
                    b_extended[i] = b[i].clone();
                    c_extended[i] = c[i].clone();
                } else {
                    b_extended[i] = b_sign_extend * byte_mask;
                    c_extended[i] = c_sign_extend * byte_mask;
                }
            }
            (b_extended, c_extended)
        };

        // m[k] = sum over i + j = k of b[i] * c[j], for k < LONG_WORD_BYTE_SIZE.
        let mut m: Vec<AB::Expr> = vec![AB::F::zero().into(); LONG_WORD_BYTE_SIZE];
        for i in 0..LONG_WORD_BYTE_SIZE {
            for j in 0..LONG_WORD_BYTE_SIZE {
                if i + j < LONG_WORD_BYTE_SIZE {
                    m[i + j] = m[i + j].clone() + b[i].clone() * c[j].clone();
                }
            }
        }

        // product[i] = m[i] + carry[i - 1] - carry[i] * 256, with no incoming carry at i = 0.
        let product = {
            for i in 0..LONG_WORD_BYTE_SIZE {
                if i == 0 {
                    builder
                        .when(is_real.clone())
                        .assert_eq(cols.product[i], m[i].clone() - cols.carry[i] * base);
                } else {
                    builder.when(is_real.clone()).assert_eq(
                        cols.product[i],
                        m[i].clone() + cols.carry[i - 1] - cols.carry[i] * base,
                    );
                }
            }
            cols.product
        };

        // Select the u16 limbs of `a_word` from the product according to the active opcode.
        {
            let is_lower = is_mul.clone();
            let is_upper = is_mulh.clone() + is_mulhu.clone() + is_mulhsu.clone();
            let is_word = is_mulw.clone();
            let u16_max = AB::F::from_canonical_u32((1 << 16) - 1);
            for i in 0..WORD_SIZE {
                if i < WORD_SIZE / 2 {
                    // MULW: the low limbs come straight from the product.
                    builder.when(is_word.clone()).assert_eq(
                        product[2 * i] + product[2 * i + 1] * AB::F::from_canonical_u16(1 << 8),
                        a_word[i].clone(),
                    );
                } else {
                    // MULW: the high limbs are the sign extension of the 32-bit result.
                    builder
                        .when(is_word.clone())
                        .assert_eq(cols.product_msb.msb * u16_max, a_word[i].clone());
                }
                builder.when(is_lower.clone()).assert_eq(
                    product[2 * i] + product[2 * i + 1] * AB::F::from_canonical_u16(1 << 8),
                    a_word[i].clone(),
                );
                builder.when(is_upper.clone()).assert_eq(
                    product[2 * i + WORD_BYTE_SIZE]
                        + product[2 * i + 1 + WORD_BYTE_SIZE] * AB::F::from_canonical_u16(1 << 8),
                    a_word[i].clone(),
                );
            }
        }

        // Boolean checks on the flag columns, each selector, their sum, and `is_real`.
        {
            let booleans = [
                cols.b_msb.into(),
                cols.c_msb.into(),
                cols.b_sign_extend.into(),
                cols.c_sign_extend.into(),
                is_mul.clone(),
                is_mulh.clone(),
                is_mulhu.clone(),
                is_mulhsu.clone(),
                is_mulw.clone(),
                is_mul.clone()
                    + is_mulh.clone()
                    + is_mulhu.clone()
                    + is_mulhsu.clone()
                    + is_mulw.clone(),
                is_real.clone(),
            ];
            for boolean in booleans.iter() {
                builder.assert_bool(boolean.clone());
            }
        }

        // Explicit check that sign extension only happens for a set MSB. It appears implied by
        // the sign-extend equalities plus the boolean checks above.
        builder.when(cols.b_sign_extend).assert_eq(cols.b_msb, one.clone());
        builder.when(cols.c_sign_extend).assert_eq(cols.c_msb, one.clone());

        // Bound every product byte to u8 and every carry to u16 on real rows.
        {
            builder.slice_range_check_u16(&cols.carry, is_real.clone());
            builder.slice_range_check_u8(&cols.product, is_real.clone());
        }
    }
}
/// Inputs for lowering a [`MulOperation`]; `lower` forwards each field to [`MulOperation::eval`].
///
/// NOTE(review): the `InputParams` derive presumably generates a `new` constructor whose
/// parameters follow field order (compare `U16MSBOperationInput::new`). Keep the order stable.
#[derive(Debug, Clone, InputExpr, InputParams)]
pub struct MulOperationInput<AB: SP1AirBuilder> {
    /// Result word `a`.
    pub a_word: Word<AB::Expr>,
    /// First operand `b`.
    pub b_word: Word<AB::Expr>,
    /// Second operand `c`.
    pub c_word: Word<AB::Expr>,
    /// Trace columns of the operation.
    pub cols: MulOperation<AB::Var>,
    /// Row flag; gates the lookups, byte decompositions and product constraints in `eval`.
    pub is_real: AB::Expr,
    /// Selector: `a_word` is the low half of the product.
    pub is_mul: AB::Expr,
    /// Selector: `a_word` is the high half, with both operands sign-extended.
    pub is_mulh: AB::Expr,
    /// Selector: `a_word` is the low 32 bits of the product, sign-extended to 64 bits.
    pub is_mulw: AB::Expr,
    /// Selector: `a_word` is the high half, with no sign extension.
    pub is_mulhu: AB::Expr,
    /// Selector: `a_word` is the high half, with only `b` sign-extended.
    pub is_mulhsu: AB::Expr,
}
impl<AB> SP1Operation<AB> for MulOperation<AB::F>
where
AB: SP1AirBuilder
+ SP1OperationBuilder<U16toU8OperationSafe>
+ SP1OperationBuilder<U16MSBOperation<<AB as AirBuilder>::F>>,
{
type Input = MulOperationInput<AB>;
type Output = ();
fn lower(builder: &mut AB, input: Self::Input) -> Self::Output {
Self::eval(
builder,
input.a_word,
input.b_word,
input.c_word,
input.cols,
input.is_real,
input.is_mul,
input.is_mulh,
input.is_mulw,
input.is_mulhu,
input.is_mulhsu,
);
}
}