use core::{
borrow::{Borrow, BorrowMut},
mem::{size_of, MaybeUninit},
};
use hashbrown::HashMap;
use itertools::Itertools;
use slop_air::{Air, AirBuilder, BaseAir};
use slop_algebra::{AbstractField, Field, PrimeField, PrimeField32};
use slop_matrix::Matrix;
use slop_maybe_rayon::prelude::{ParallelBridge, ParallelIterator, ParallelSlice};
use sp1_core_executor::{
events::{AluEvent, ByteLookupEvent, ByteRecord},
ByteOpcode, ExecutionRecord, Opcode, Program, CLK_INC, PC_INC,
};
use sp1_derive::AlignedBorrow;
use sp1_hypercube::{air::MachineAir, Word};
use sp1_primitives::consts::{u64_to_u16_limbs, WORD_SIZE};
use crate::{
adapter::{
register::alu_type::{ALUTypeReader, ALUTypeReaderInput},
state::{CPUState, CPUStateInput},
},
air::{SP1CoreAirBuilder, SP1Operation},
operations::{U16MSBOperation, U16MSBOperationInput},
utils::next_multiple_of_32,
};
/// Width of one `ShiftRight` trace row.
///
/// Computed as the byte size of `ShiftRightCols<u8>`; this is the value returned by the chip's
/// `BaseAir::width` and the row stride used when slicing the trace buffer.
pub const NUM_SHIFT_RIGHT_COLS: usize = size_of::<ShiftRightCols<u8>>();
/// Chip that handles the `SRL`, `SRA`, `SRLW` and `SRAW` opcodes.
///
/// The type carries no state; its behaviour lives in the `MachineAir`, `BaseAir` and `Air`
/// implementations.
#[derive(Default)]
pub struct ShiftRightChip;
/// Column layout of one `ShiftRight` trace row.
///
/// The struct is `#[repr(C)]` and is borrowed directly from a flat slice of
/// `NUM_SHIFT_RIGHT_COLS` values, so the order of the fields is part of the trace format.
#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
#[repr(C)]
pub struct ShiftRightCols<T> {
/// CPU state columns, populated from the event's `clk` and `pc`.
pub state: CPUState<T>,
/// ALU-type operand reader columns; the AIR reads `b()`, `c()` and `imm_c` from here.
pub adapter: ALUTypeReader<T>,
/// The result word, populated from the event's `a`.
pub a: Word<T>,
/// MSB of the operand limb used for sign fill: limb 3 of `b` for `SRA`, limb 1 of `b` for
/// `SRAW`. Its `msb` is zero for `SRL` and `SRLW`.
pub b_msb: U16MSBOperation<T>,
/// MSB of limb 1 of the result `a` for the word opcodes (`SRLW`/`SRAW`); its `msb` is zero
/// otherwise.
pub srw_msb: U16MSBOperation<T>,
/// The six low bits of the first limb of `c`, least significant bit first.
pub c_bits: [T; 6],
/// The product `b_msb.msb * v_0123`.
pub sra_msb_v0123: T,
/// `2^(16 - (c & 15))`; set to `65536` on padding rows.
pub v_0123: T,
/// `2^(8 - (c & 7))`; set to `256` on padding rows.
pub v_012: T,
/// `2^(4 - (c & 3))`; set to `16` on padding rows.
pub v_01: T,
/// For each limb of `b`, the low `c & 15` bits (the bits shifted out of that limb).
pub lower_limb: Word<T>,
/// For each limb of `b`, the limb shifted right by `c & 15` bits.
pub higher_limb: Word<T>,
/// `higher_limb[i] + lower_limb[i + 1] * 2^(16 - (c & 15))`, with no carry-in term for the
/// last limb.
pub limb_result: [T; WORD_SIZE],
/// One-hot selector of the whole-limb shift amount: bit 4 of `c` plus, for the non-word
/// opcodes only, twice bit 5 of `c`.
pub shift_u16: [T; 4],
/// Set when the opcode is `SRL`.
pub is_srl: T,
/// Set when the opcode is `SRA`.
pub is_sra: T,
/// Set when the opcode is `SRLW`.
pub is_srlw: T,
/// Set when the opcode is `SRAW`.
pub is_sraw: T,
/// Set when the opcode is `SRLW` or `SRAW` and operand `c` is an immediate.
pub is_w_imm: T,
}
impl<F: PrimeField32> MachineAir<F> for ShiftRightChip {
    type Record = ExecutionRecord;

    type Program = Program;

    /// Returns the chip's name, `"ShiftRight"`.
    fn name(&self) -> &'static str {
        "ShiftRight"
    }

    /// Returns the padded trace height for `input`.
    ///
    /// The height is `next_multiple_of_32` applied to the number of shift-right events and the
    /// record's optional fixed log2 height for this chip. The result is always `Some`.
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let nb_rows = next_multiple_of_32(
            input.shift_right_events.len(),
            input.fixed_log2_rows::<F, _>(self),
        );
        Some(nb_rows)
    }

    /// Writes the row-major trace for `input` into `buffer`.
    ///
    /// The first `shift_right_events.len()` rows are populated from the events; the remaining
    /// rows up to `num_rows` are zeroed padding, except for `v_01`, `v_012` and `v_0123`, which
    /// are set to the values they take for a zero shift amount (16, 256, 65536). Byte-lookup
    /// events produced while populating rows are discarded here; `generate_dependencies`
    /// collects them.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` holds fewer than `num_rows * NUM_SHIFT_RIGHT_COLS` elements, or if the
    /// padded height is smaller than the number of events.
    fn generate_trace_into(
        &self,
        input: &ExecutionRecord,
        _output: &mut ExecutionRecord,
        buffer: &mut [MaybeUninit<F>],
    ) {
        let nb_rows = input.shift_right_events.len();
        // `num_rows` for this chip always returns `Some`.
        let padded_nb_rows = <ShiftRightChip as MachineAir<F>>::num_rows(self, input).unwrap();
        let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1);

        let trace_len = padded_nb_rows * NUM_SHIFT_RIGHT_COLS;
        let padding_start = nb_rows * NUM_SHIFT_RIGHT_COLS;

        // Fail loudly instead of writing out of bounds: both raw-pointer operations below are
        // derived from this bounds-checked sub-slice rather than from unchecked arithmetic.
        assert!(
            buffer.len() >= trace_len,
            "ShiftRight trace buffer holds {} elements but {} are required",
            buffer.len(),
            trace_len,
        );
        let trace = &mut buffer[..trace_len];

        // Slicing panics if the padded height is smaller than the number of events.
        let padding = &mut trace[padding_start..];
        if !padding.is_empty() {
            // SAFETY: `padding` is an exclusively borrowed, live slice, so writing
            // `padding.len()` elements through its pointer stays in bounds, and any byte
            // pattern is valid for `MaybeUninit<F>`.
            unsafe { core::ptr::write_bytes(padding.as_mut_ptr(), 0, padding.len()) };
        }

        // SAFETY: the pointer and length come from `trace`, an exclusive borrow that is not
        // used again, and `MaybeUninit<F>` has the same size and alignment as `F`.
        // NOTE(review): this exposes the not-yet-written real rows as `F` and relies on
        // all-zero bytes being a valid `F` (zero) for the padding rows — confirm both hold for
        // the field in use. Every real row is overwritten field by field below.
        let values =
            unsafe { core::slice::from_raw_parts_mut(trace.as_mut_ptr().cast::<F>(), trace.len()) };

        values.chunks_mut(chunk_size * NUM_SHIFT_RIGHT_COLS).enumerate().par_bridge().for_each(
            |(i, rows)| {
                rows.chunks_mut(NUM_SHIFT_RIGHT_COLS).enumerate().for_each(|(j, row)| {
                    let idx = i * chunk_size + j;
                    let cols: &mut ShiftRightCols<F> = row.borrow_mut();
                    if idx < nb_rows {
                        // Lookup events are only needed by `generate_dependencies`; this
                        // scratch record is dropped at the end of the row.
                        let mut byte_lookup_events = Vec::new();
                        let event = &input.shift_right_events[idx];
                        cols.adapter.populate(&mut byte_lookup_events, event.1);
                        self.event_to_row(&event.0, cols, &mut byte_lookup_events);
                        cols.state.populate(&mut byte_lookup_events, event.0.clk, event.0.pc);
                        cols.is_w_imm = F::from_bool(
                            (event.0.opcode == Opcode::SRLW || event.0.opcode == Opcode::SRAW)
                                && event.1.is_imm,
                        );
                    } else {
                        // Padding row: all `c_bits` are zero, so the power-of-two helper
                        // columns must hold 2^4, 2^8 and 2^16 to satisfy the AIR.
                        cols.v_01 = F::from_canonical_u32(16);
                        cols.v_012 = F::from_canonical_u32(256);
                        cols.v_0123 = F::from_canonical_u32(65536);
                    }
                });
            },
        );
    }

    /// Collects the byte-lookup events implied by every shift-right event in `input` and adds
    /// them to `output`.
    ///
    /// Each event is replayed into a throwaway row purely for the lookups it records; the
    /// events are processed in parallel chunks, one lookup map per chunk.
    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
        let chunk_size = std::cmp::max(input.shift_right_events.len() / num_cpus::get(), 1);
        let blu_batches = input
            .shift_right_events
            .par_chunks(chunk_size)
            .map(|events| {
                let mut blu: HashMap<ByteLookupEvent, usize> = HashMap::new();
                events.iter().for_each(|event| {
                    let mut row = [F::zero(); NUM_SHIFT_RIGHT_COLS];
                    let cols: &mut ShiftRightCols<F> = row.as_mut_slice().borrow_mut();
                    cols.adapter.populate(&mut blu, event.1);
                    self.event_to_row(&event.0, cols, &mut blu);
                    cols.state.populate(&mut blu, event.0.clk, event.0.pc);
                });
                blu
            })
            .collect::<Vec<_>>();
        output.add_byte_lookup_events_from_maps(blu_batches.iter().collect_vec());
    }

    /// Reports whether this chip takes part in `shard`.
    ///
    /// When the shard carries a shape, the shape decides; otherwise the chip is included
    /// exactly when the shard has at least one shift-right event.
    fn included(&self, shard: &Self::Record) -> bool {
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            !shard.shift_right_events.is_empty()
        }
    }
}
impl ShiftRightChip {
    /// Fills the shift-specific columns of one trace row from an ALU event.
    ///
    /// Every field of `cols` is written except `state`, `adapter` and `is_w_imm`, which the
    /// callers populate themselves. The range checks on the shift amount and on the split
    /// limbs are recorded in `blu`, which is also handed to the MSB operations.
    fn event_to_row<F: PrimeField>(
        &self,
        event: &AluEvent,
        cols: &mut ShiftRightCols<F>,
        blu: &mut impl ByteRecord,
    ) {
        let mut operand = u64_to_u16_limbs(event.b);
        let amount = u64_to_u16_limbs(event.c)[0];
        let word_variant = matches!(event.opcode, Opcode::SRLW | Opcode::SRAW);

        cols.a = Word::from(event.a);

        // Opcode selectors: at most one of these is set.
        cols.is_srl = F::from_bool(matches!(event.opcode, Opcode::SRL));
        cols.is_sra = F::from_bool(matches!(event.opcode, Opcode::SRA));
        cols.is_srlw = F::from_bool(matches!(event.opcode, Opcode::SRLW));
        cols.is_sraw = F::from_bool(matches!(event.opcode, Opcode::SRAW));

        // Low six bits of the shift amount; everything above them is range checked to 10 bits.
        for (pos, bit) in cols.c_bits.iter_mut().enumerate() {
            *bit = F::from_canonical_u16((amount >> pos) & 1);
        }
        blu.add_bit_range_check(amount >> 6, 10);

        // Powers of two derived from the low two, three and four bits of the shift amount.
        cols.v_01 = F::from_canonical_u32(1 << (4 - (amount & 3)));
        cols.v_012 = F::from_canonical_u32(1 << (8 - (amount & 7)));
        cols.v_0123 = F::from_canonical_u32(1 << (16 - (amount & 15)));

        // Sign source: the top limb for SRA, limb 1 for SRAW, and no sign fill for the
        // logical shifts. This runs before the upper limbs are cleared for word opcodes.
        match event.opcode {
            Opcode::SRA => {
                cols.b_msb.populate_msb(blu, operand[3]);
            }
            Opcode::SRAW => {
                cols.b_msb.populate_msb(blu, operand[1]);
            }
            _ => {
                cols.b_msb.msb = F::zero();
            }
        }
        cols.sra_msb_v0123 = cols.b_msb.msb * cols.v_0123;

        if word_variant {
            // Word opcodes ignore the upper two limbs of `b` and take the MSB of limb 1 of
            // the result.
            operand[2] = 0;
            operand[3] = 0;
            cols.srw_msb.populate_msb(blu, u64_to_u16_limbs(event.a)[1]);
        } else {
            cols.srw_msb.msb = F::zero();
        }

        // Split every limb at the sub-limb shift amount and range check both halves.
        let bit_shift = (amount & 0xF) as u8;
        let low_mask: u32 = (1 << bit_shift) - 1;
        for (i, &limb) in operand.iter().enumerate().take(WORD_SIZE) {
            let wide = u32::from(limb);
            let low_part = (wide & low_mask) as u16;
            let high_part = (wide >> bit_shift) as u16;
            cols.lower_limb[i] = F::from_canonical_u16(low_part);
            cols.higher_limb[i] = F::from_canonical_u16(high_part);
            blu.add_bit_range_check(low_part, bit_shift);
            blu.add_bit_range_check(high_part, 16 - bit_shift);
        }

        // Each shifted limb receives the bits shifted out of the limb above it; the last limb
        // has no neighbour above.
        let carry_scale = F::from_canonical_u32(1 << (16 - bit_shift));
        for i in 0..WORD_SIZE {
            let carry_in = if i + 1 < WORD_SIZE {
                cols.lower_limb[i + 1] * carry_scale
            } else {
                F::zero()
            };
            cols.limb_result[i] = cols.higher_limb[i] + carry_in;
        }

        // One-hot whole-limb shift: bit 4 always counts, bit 5 only for the non-word opcodes.
        let bit_four = usize::from((amount >> 4) & 1);
        let bit_five = usize::from((amount >> 5) & 1);
        let limb_shift = if word_variant { bit_four } else { bit_four + 2 * bit_five };
        cols.shift_u16 = core::array::from_fn(|slot| F::from_bool(slot == limb_shift));
    }
}
impl<F> BaseAir<F> for ShiftRightChip {
/// Returns the number of columns in a trace row, `NUM_SHIFT_RIGHT_COLS`.
fn width(&self) -> usize {
NUM_SHIFT_RIGHT_COLS
}
}
impl<AB> Air<AB> for ShiftRightChip
where
AB: SP1CoreAirBuilder,
{
/// Emits the constraints and lookups for one `ShiftRight` row.
///
/// In order: opcode selectors and instruction encoding, decomposition of the shift amount,
/// the per-limb split of `b` at the sub-limb shift, the sign-bit operations, the assembly of
/// the result `a` from the shifted limbs, and finally the CPU state and operand reader checks.
fn eval(&self, builder: &mut AB) {
let main = builder.main();
let local = main.row_slice(0);
let local: &ShiftRightCols<AB::Var> = (*local).borrow();
// Each selector is boolean and so is their sum, so at most one opcode is selected; the row
// is real exactly when one is.
let is_real = local.is_srl + local.is_sra + local.is_srlw + local.is_sraw;
builder.assert_bool(local.is_srl);
builder.assert_bool(local.is_sra);
builder.assert_bool(local.is_srlw);
builder.assert_bool(local.is_sraw);
builder.assert_bool(is_real.clone());
let one = AB::Expr::one();
let is_word = local.is_srlw + local.is_sraw;
let not_word = local.is_srl + local.is_sra;
// Opcode, funct3 and funct7 of the selected instruction (all zero on a padding row).
let opcode = local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32)
+ local.is_sra * AB::F::from_canonical_u32(Opcode::SRA as u32)
+ local.is_srlw * AB::F::from_canonical_u32(Opcode::SRLW as u32)
+ local.is_sraw * AB::F::from_canonical_u32(Opcode::SRAW as u32);
let funct3 = local.is_srl * AB::Expr::from_canonical_u8(Opcode::SRL.funct3().unwrap())
+ local.is_sra * AB::Expr::from_canonical_u8(Opcode::SRA.funct3().unwrap())
+ local.is_srlw * AB::Expr::from_canonical_u8(Opcode::SRLW.funct3().unwrap())
+ local.is_sraw * AB::Expr::from_canonical_u8(Opcode::SRAW.funct3().unwrap());
let funct7 = local.is_srl * AB::Expr::from_canonical_u8(Opcode::SRL.funct7().unwrap_or(0))
+ local.is_sra * AB::Expr::from_canonical_u8(Opcode::SRA.funct7().unwrap())
+ local.is_srlw * AB::Expr::from_canonical_u8(Opcode::SRLW.funct7().unwrap_or(0))
+ local.is_sraw * AB::Expr::from_canonical_u8(Opcode::SRAW.funct7().unwrap());
// Base opcodes for the register and immediate forms. The immediate form is derived by
// subtracting one shared constant, so the asserts below check that the base/immediate
// difference is the same for all four opcodes.
let (srl_base, srl_imm) = Opcode::SRL.base_opcode();
let srl_imm = srl_imm.expect("SRL immediate opcode not found");
let (sra_base, sra_imm) = Opcode::SRA.base_opcode();
let sra_imm = sra_imm.expect("SRA immediate opcode not found");
let (srlw_base, srlw_imm) = Opcode::SRLW.base_opcode();
let srlw_imm = srlw_imm.expect("SRLW immediate opcode not found");
let (sraw_base, sraw_imm) = Opcode::SRAW.base_opcode();
let sraw_imm = sraw_imm.expect("SRAW immediate opcode not found");
let imm_base_difference = srl_base.checked_sub(srl_imm).unwrap();
assert_eq!(imm_base_difference, sra_base.checked_sub(sra_imm).unwrap());
assert_eq!(imm_base_difference, srlw_base.checked_sub(srlw_imm).unwrap());
assert_eq!(imm_base_difference, sraw_base.checked_sub(sraw_imm).unwrap());
let srl_base_expr = AB::Expr::from_canonical_u32(srl_base);
let sra_base_expr = AB::Expr::from_canonical_u32(sra_base);
let srlw_base_expr = AB::Expr::from_canonical_u32(srlw_base);
let sraw_base_expr = AB::Expr::from_canonical_u32(sraw_base);
let calculated_base_opcode = local.is_srl * srl_base_expr
+ local.is_sra * sra_base_expr
+ local.is_srlw * srlw_base_expr
+ local.is_sraw * sraw_base_expr
- AB::Expr::from_canonical_u32(imm_base_difference) * local.adapter.imm_c;
// Instruction types follow the same scheme, except that the word opcodes' immediate form
// differs by an extra `w_instr_imm_adjustment`, which is applied through `is_w_imm`.
let srl_instr_type = Opcode::SRL.instruction_type().0 as u32;
let srl_instr_type_imm =
Opcode::SRL.instruction_type().1.expect("SRL immediate instruction type not found")
as u32;
let sra_instr_type = Opcode::SRA.instruction_type().0 as u32;
let sra_instr_type_imm =
Opcode::SRA.instruction_type().1.expect("SRA immediate instruction type not found")
as u32;
let srlw_instr_type = Opcode::SRLW.instruction_type().0 as u32;
let srlw_instr_type_imm =
Opcode::SRLW.instruction_type().1.expect("SRLW immediate instruction type not found")
as u32;
let sraw_instr_type = Opcode::SRAW.instruction_type().0 as u32;
let sraw_instr_type_imm =
Opcode::SRAW.instruction_type().1.expect("SRAW immediate instruction type not found")
as u32;
let instr_type_difference = srl_instr_type.checked_sub(srl_instr_type_imm).unwrap();
let sra_instr_type_difference = sra_instr_type.checked_sub(sra_instr_type_imm).unwrap();
let srlw_instr_type_difference = srlw_instr_type.checked_sub(srlw_instr_type_imm).unwrap();
let sraw_instr_type_difference = sraw_instr_type.checked_sub(sraw_instr_type_imm).unwrap();
let w_instr_imm_adjustment = srl_instr_type_imm.checked_sub(srlw_instr_type_imm).unwrap();
assert_eq!(instr_type_difference, sra_instr_type_difference);
assert_eq!(srlw_instr_type_difference, instr_type_difference + w_instr_imm_adjustment);
assert_eq!(sraw_instr_type_difference, instr_type_difference + w_instr_imm_adjustment);
// `is_w_imm` is the product of "word opcode" and "c is an immediate".
builder.assert_eq(local.is_w_imm, (local.is_srlw + local.is_sraw) * local.adapter.imm_c);
let calculated_instr_type = local.is_srl * AB::Expr::from_canonical_u32(srl_instr_type)
+ local.is_sra * AB::Expr::from_canonical_u32(sra_instr_type)
+ local.is_srlw * AB::Expr::from_canonical_u32(srlw_instr_type)
+ local.is_sraw * AB::Expr::from_canonical_u32(sraw_instr_type)
- (AB::Expr::from_canonical_u32(instr_type_difference) * local.adapter.imm_c
+ AB::Expr::from_canonical_u32(w_instr_imm_adjustment) * local.is_w_imm);
// Shift amount: six boolean bits recomposed into `c_lower_bits`. `bit_shift` is captured
// after bit 3 has been added, so it holds the low four bits (the sub-limb shift).
for i in 0..6 {
builder.assert_bool(local.c_bits[i]);
}
let mut c_lower_bits = AB::Expr::zero();
let mut bit_shift = AB::Expr::zero();
for i in 0..6 {
c_lower_bits = c_lower_bits + local.c_bits[i] * AB::F::from_canonical_u32(1 << i);
if i == 3 {
bit_shift = c_lower_bits.clone();
}
}
// Tie `c_bits` to the first limb of `c`: the remainder `(c[0] - c_lower_bits) / 64` is
// sent to the range lookup with a width of 10 bits.
let inverse_64 = AB::F::from_canonical_u32(64).inverse();
builder.send_byte(
AB::F::from_canonical_u32(ByteOpcode::Range as u32),
(local.adapter.c()[0] - c_lower_bits) * inverse_64,
AB::Expr::from_canonical_u32(10),
AB::Expr::zero(),
is_real.clone(),
);
// `shift_u16` is a boolean selector: a set entry `i` forces the whole-limb shift
// `c_bits[4] + 2 * c_bits[5] * not_word` to equal `i`, and on real rows the entries sum to 1.
for i in 0..WORD_SIZE {
builder.when(local.shift_u16[i]).assert_eq(
local.c_bits[4] + local.c_bits[5] * AB::F::from_canonical_u32(2) * not_word.clone(),
AB::Expr::from_canonical_u32(i as u32),
);
builder.assert_bool(local.shift_u16[i]);
}
builder.when(is_real.clone()).assert_eq(
local.shift_u16[0] + local.shift_u16[1] + local.shift_u16[2] + local.shift_u16[3],
AB::Expr::from_canonical_u32(1),
);
let two = AB::F::from_canonical_u32(2);
let three = AB::F::from_canonical_u32(3);
let fifteen = AB::F::from_canonical_u32(15);
let two_fifty_five = AB::F::from_canonical_u32(255);
// v_01 = (2 - c_bits[0]) * 2 * (4 or 1) = 2^(4 - c_bits[0] - 2 * c_bits[1]).
builder.assert_eq(
local.v_01,
(((one.clone() - local.c_bits[0]) + one.clone()) * two)
* ((one.clone() - local.c_bits[1]) * three + one.clone()),
);
// v_012 = v_01 * 16^(1 - c_bits[2]).
builder.assert_eq(
local.v_012,
local.v_01 * ((one.clone() - local.c_bits[2]) * fifteen + one.clone()),
);
// v_0123 = v_012 * 256^(1 - c_bits[3]), i.e. 2^(16 - bit_shift).
builder.assert_eq(
local.v_0123,
local.v_012 * ((one.clone() - local.c_bits[3]) * two_fifty_five + one.clone()),
);
// Split every limb of `b` at `bit_shift`: `lower_limb` is range checked to `bit_shift` bits
// and `higher_limb` to `16 - bit_shift` bits. With v_0123 = 2^(16 - bit_shift), the equation
// `limb * v_0123 = higher * 2^16 + lower * v_0123` states `limb = higher * 2^bit_shift + lower`.
for i in 0..WORD_SIZE {
let limb = local.adapter.b()[i];
builder.send_byte(
AB::F::from_canonical_u32(ByteOpcode::Range as u32),
local.lower_limb[i],
bit_shift.clone(),
AB::Expr::zero(),
is_real.clone(),
);
builder.send_byte(
AB::F::from_canonical_u32(ByteOpcode::Range as u32),
local.higher_limb[i],
AB::Expr::from_canonical_u32(16) - bit_shift.clone(),
AB::Expr::zero(),
is_real.clone(),
);
if i < WORD_SIZE / 2 {
builder.assert_eq(
limb * local.v_0123,
local.higher_limb[i] * AB::Expr::from_canonical_u32(1 << 16)
+ local.lower_limb[i] * local.v_0123,
);
} else {
// Upper half: the limb is multiplied by `not_word`, so word opcodes split zero here
// instead of the actual upper limbs of `b`.
builder.assert_eq(
limb * local.v_0123 * not_word.clone(),
local.higher_limb[i] * AB::Expr::from_canonical_u32(1 << 16)
+ local.lower_limb[i] * local.v_0123,
);
}
}
// Result of the sub-limb shift: each limb keeps its high part and receives the bits shifted
// out of the limb above it (the last limb has no limb above).
for i in 0..WORD_SIZE {
let mut limb_result = local.higher_limb[i].into();
if i != WORD_SIZE - 1 {
limb_result = limb_result.clone() + local.lower_limb[i + 1] * local.v_0123;
}
builder.assert_eq(local.limb_result[i], limb_result);
}
// `b_msb` is checked against limb 3 of `b` when the opcode is SRA and against limb 1 when it
// is SRAW; for the logical shifts its `msb` must be zero.
<U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
builder,
U16MSBOperationInput::<AB>::new(
local.adapter.b().0[3].into(),
local.b_msb,
local.is_sra.into(),
),
);
<U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
builder,
U16MSBOperationInput::<AB>::new(
local.adapter.b().0[1].into(),
local.b_msb,
local.is_sraw.into(),
),
);
builder.when(local.is_srl + local.is_srlw).assert_zero(local.b_msb.msb);
builder.assert_eq(local.sra_msb_v0123, local.b_msb.msb * local.v_0123);
// `srw_msb` is checked against limb 1 of the result for word opcodes and is zero otherwise.
<U16MSBOperation<AB::F> as SP1Operation<AB>>::eval(
builder,
U16MSBOperationInput::<AB>::new(local.a.0[1].into(), local.srw_msb, is_word.clone()),
);
builder.when_not(is_word.clone()).assert_zero(local.srw_msb.msb);
let base = AB::F::from_canonical_u32(1 << 16);
let base_minus_one = AB::F::from_canonical_u16(u16::MAX);
// Non-word opcodes, shift by `i` whole limbs: the low result limbs copy `limb_result[i + j]`,
// the limb taken from the top of `b` also gets the fill `msb * (2^16 - v_0123)`, and every
// limb above it is `msb * 0xFFFF`.
for i in 0..WORD_SIZE {
for j in 0..(WORD_SIZE - 1 - i) {
builder
.when(not_word.clone())
.when(local.shift_u16[i])
.assert_eq(local.a[j], local.limb_result[i + j]);
}
builder.when(not_word.clone()).when(local.shift_u16[i]).assert_eq(
local.a[WORD_SIZE - 1 - i],
local.limb_result[WORD_SIZE - 1] + (local.b_msb.msb * base - local.sra_msb_v0123),
);
for j in (WORD_SIZE - i)..WORD_SIZE {
builder
.when(not_word.clone())
.when(local.shift_u16[i])
.assert_eq(local.a[j], local.b_msb.msb * base_minus_one);
}
}
// Word opcodes: the same scheme restricted to the two low limbs, for a shift of zero or one
// whole limb.
builder
.when(is_word.clone())
.when(local.shift_u16[0])
.assert_eq(local.a[0], local.limb_result[0]);
builder.when(is_word.clone()).when(local.shift_u16[0]).assert_eq(
local.a[1],
local.limb_result[1] + (local.b_msb.msb * base - local.sra_msb_v0123),
);
builder.when(is_word.clone()).when(local.shift_u16[1]).assert_eq(
local.a[0],
local.limb_result[1] + (local.b_msb.msb * base - local.sra_msb_v0123),
);
builder
.when(is_word.clone())
.when(local.shift_u16[1])
.assert_eq(local.a[1], local.b_msb.msb * base_minus_one);
// Word opcodes: the upper two limbs of the result are `srw_msb.msb * 0xFFFF`.
for i in WORD_SIZE / 2..WORD_SIZE {
builder.when(is_word.clone()).assert_eq(local.a[i], local.srw_msb.msb * base_minus_one);
}
// CPU state check: `PC_INC` is added to the low pc limb, the other pc limbs are passed
// through unchanged, and the clock increment is `CLK_INC`.
<CPUState<AB::F> as SP1Operation<AB>>::eval(
builder,
CPUStateInput {
cols: local.state,
next_pc: [
local.state.pc[0] + AB::F::from_canonical_u32(PC_INC),
local.state.pc[1].into(),
local.state.pc[2].into(),
],
clk_increment: AB::Expr::from_canonical_u32(CLK_INC),
is_real: is_real.clone(),
},
);
// Operand reader check, given the computed opcode, instruction encoding and result `a`.
let alu_reader_input = ALUTypeReaderInput::<AB, AB::Expr>::new(
local.state.clk_high::<AB>(),
local.state.clk_low::<AB>(),
local.state.pc,
opcode,
[calculated_instr_type, calculated_base_opcode, funct3, funct7],
local.a.map(|x| x.into()),
local.adapter,
is_real,
);
ALUTypeReader::<AB::F>::eval(builder, alu_reader_input);
}
}