use slop_air::AirBuilder;
use slop_algebra::{AbstractField, Field};
use sp1_core_executor::{events::ByteRecord, ByteOpcode};
use sp1_derive::AlignedBorrow;
use sp1_hypercube::air::SP1AirBuilder;
use sp1_primitives::consts::u32_to_u16_limbs;
use crate::utils::u32_to_half_word;
/// Witness columns for constraining a right shift of a 32-bit word, given as two 16-bit
/// limbs, by a shift amount fixed at circuit-construction time.
#[derive(AlignedBorrow, Default, Debug, Clone, Copy)]
#[repr(C)]
pub struct FixedShiftRightOperation<T> {
/// The shifted result as two 16-bit limbs. Index 0 is the less-significant limb, as implied
/// by `eval`, which maps input limb `i + k` to position `i` when shifting right by `k` limbs.
pub value: [T; 2],
/// For each limb of the input after the whole-limb shift, the limb with its low
/// `rotation % 16` bits dropped (`limb >> (rotation % 16)`).
pub higher_limb: [T; 2],
}
impl<F: Field> FixedShiftRightOperation<F> {
    /// Number of whole 16-bit limbs that a right shift by `rotation` bits moves the input by.
    pub const fn nb_limbs_to_shift(rotation: usize) -> usize {
        rotation / 16
    }

    /// Number of bits each limb is shifted by after the whole-limb shift (`rotation % 16`).
    pub const fn nb_bits_to_shift(rotation: usize) -> usize {
        rotation % 16
    }

    /// Returns `2^(16 - nb_bits_to_shift(rotation))`.
    ///
    /// This factor moves the low bits shifted out of limb 1 up to the top of result limb 0.
    pub const fn carry_multiplier(rotation: usize) -> u32 {
        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
        1 << (16 - nb_bits_to_shift)
    }

    /// Fills the witness columns for `input >> rotation` and records the bit-range checks
    /// that `eval` sends for this row.
    ///
    /// Returns the shifted value. A shift of 32 or more bits yields `0`. This matches the
    /// constraints in `eval`, where every input limb is shifted out and the result is forced
    /// to zero.
    pub fn populate(&mut self, record: &mut impl ByteRecord, input: u32, rotation: usize) -> u32 {
        let input_limbs = u32_to_u16_limbs(input);

        // `u32 >> n` with `n >= 32` panics when overflow checks are enabled, and otherwise
        // masks the shift amount (so `>> 32` would behave like `>> 0`). Neither agrees with
        // `eval`, which constrains the result to zero in that case.
        let expected = if rotation < 32 { input >> rotation } else { 0 };
        self.value = u32_to_half_word(expected);

        let nb_limbs_to_shift = Self::nb_limbs_to_shift(rotation);
        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);

        // Shift by whole limbs first; limb positions shifted in from above are zero.
        let word: [u16; 2] = std::array::from_fn(|i| {
            if i + nb_limbs_to_shift < 2 {
                input_limbs[i + nb_limbs_to_shift]
            } else {
                0
            }
        });

        // Split each limb into its low `nb_bits_to_shift` bits (shifted out of this limb)
        // and the remaining high bits. Both parts are range checked, mirroring the
        // lookups sent in `eval`.
        for i in (0..2).rev() {
            let limb = word[i];
            let lower_limb = limb & ((1u16 << nb_bits_to_shift) - 1);
            let higher_limb = limb >> nb_bits_to_shift;
            self.higher_limb[i] = F::from_canonical_u16(higher_limb);
            record.add_bit_range_check(lower_limb, nb_bits_to_shift as u8);
            record.add_bit_range_check(higher_limb, (16 - nb_bits_to_shift) as u8);
        }

        expected
    }

    /// Constrains `cols.value` to equal `input >> rotation`, where `input` is a word given as
    /// two 16-bit limbs (index 0 is the less-significant limb).
    ///
    /// Each limb of the limb-shifted input is decomposed as
    /// `limb = higher_limb * 2^b + lower_limb`, with `b = rotation % 16`. The two parts are
    /// range checked to `b` and `16 - b` bits, which makes the decomposition unique.
    /// Constraints are only enforced (and lookups only sent) when `is_real` is 1.
    pub fn eval<AB: SP1AirBuilder>(
        builder: &mut AB,
        input: [AB::Var; 2],
        rotation: usize,
        cols: FixedShiftRightOperation<AB::Var>,
        is_real: AB::Var,
    ) {
        builder.assert_bool(is_real);

        let nb_limbs_to_shift = Self::nb_limbs_to_shift(rotation);
        let nb_bits_to_shift = Self::nb_bits_to_shift(rotation);
        let carry_multiplier = AB::F::from_canonical_u32(Self::carry_multiplier(rotation));

        // Shift by whole limbs; positions shifted in from above are zero.
        let input_limbs_shifted: [AB::Expr; 2] = std::array::from_fn(|i| {
            if i + nb_limbs_to_shift < 2 {
                input[i + nb_limbs_to_shift].into()
            } else {
                AB::Expr::zero()
            }
        });

        let mut lower_limb = [AB::Expr::zero(), AB::Expr::zero()];
        for i in 0..2 {
            let limb = input_limbs_shifted[i].clone();
            lower_limb[i] =
                limb - cols.higher_limb[i] * AB::Expr::from_canonical_u32(1 << nb_bits_to_shift);
            // `lower_limb[0]` is not used in the output; its range check exists only to pin
            // down `higher_limb[0]` uniquely.
            builder.send_byte(
                AB::F::from_canonical_u32(ByteOpcode::Range as u32),
                lower_limb[i].clone(),
                AB::F::from_canonical_u32(nb_bits_to_shift as u32),
                AB::Expr::zero(),
                is_real,
            );
            builder.send_byte(
                AB::F::from_canonical_u32(ByteOpcode::Range as u32),
                cols.higher_limb[i],
                AB::Expr::from_canonical_u32(16 - nb_bits_to_shift as u32),
                AB::Expr::zero(),
                is_real,
            );
        }

        // High result limb: the high bits of the upper limb.
        builder.when(is_real).assert_eq(cols.value[1], cols.higher_limb[1]);
        // Low result limb: the high bits of the lower limb, plus the low bits shifted out of
        // the upper limb, placed at bit position `16 - b`.
        builder.when(is_real).assert_eq(
            cols.value[0],
            cols.higher_limb[0] + lower_limb[1].clone() * carry_multiplier,
        );
    }
}