use triton_vm::prelude::*;
use crate::arithmetic::u32::is_u32::IsU32;
use crate::arithmetic::u32::shift_left::ShiftLeft as ShlU32;
use crate::arithmetic::u32::shift_right::ShiftRight as ShrU32;
use crate::prelude::*;
/// Arithmetic (sign-filling) right shift of an [`i128`] by a `u32` amount.
///
/// ### Stack layout
///
/// ```text
/// BEFORE: _ [arg: i128] [shamt: u32]
/// AFTER:  _ [res: i128]
/// ```
///
/// ### Preconditions
///
/// - every one of the four limbs of `arg` is a u32
/// - `shamt` is strictly less than 128
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ShiftRight;

impl ShiftRight {
    /// Raised when the least significant limb of the argument is not a u32.
    pub const ARGUMENT_LIMB_0_NOT_U32_ERROR_ID: i128 = 320;

    /// Raised when limb 1 of the argument is not a u32.
    pub const ARGUMENT_LIMB_1_NOT_U32_ERROR_ID: i128 = 321;

    /// Raised when limb 2 of the argument is not a u32.
    pub const ARGUMENT_LIMB_2_NOT_U32_ERROR_ID: i128 = 322;

    /// Raised when the most significant limb of the argument is not a u32.
    pub const ARGUMENT_LIMB_3_NOT_U32_ERROR_ID: i128 = 323;

    /// Raised when the shift amount fails its precondition check; the snippet
    /// asserts `shamt < 128` under this ID.
    pub const SHAMT_NOT_U32_ERROR_ID: i128 = 324;
}
impl BasicSnippet for ShiftRight {
    fn parameters(&self) -> Vec<(DataType, String)> {
        vec![
            (DataType::I128, "arg".to_string()),
            (DataType::U32, "shamt".to_string()),
        ]
    }

    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![(DataType::I128, "res".to_string())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_i128_shift_right".to_string()
    }

    /// Emits the TASM for an arithmetic (sign-filling) right shift of an `i128`.
    ///
    /// The `i128` lives on the stack as four u32 limbs, most significant limb
    /// deepest. After checking the preconditions, the shift runs in two phases:
    ///
    /// 1. The local subroutine `{entrypoint}_by_32n` shifts by whole limbs. While
    ///    the remaining shift amount is at least 32, it drops the least
    ///    significant limb and inserts a sign-fill limb (`u32::MAX` if the sign
    ///    bit is set, `0` otherwise) as the new most significant limb. It then
    ///    subtracts 32 from the shift amount.
    /// 2. The remaining shift `s < 32` is applied limb-wise. If `s == 0`, the
    ///    snippet returns early via `{entrypoint}_early_return`. This avoids a
    ///    left shift by `32 - 0 = 32`, presumably outside the range `ShlU32`
    ///    accepts (confirm). Otherwise every result limb is
    ///    `(next_more_significant_limb << (32 - s)) + (limb >> s)`, where the
    ///    limb above the most significant one is the sign-fill limb.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let shr_i128_by_32n = format!("{entrypoint}_by_32n");
        let clean_up_for_early_return = format!("{entrypoint}_early_return");

        let is_u32 = library.import(Box::new(IsU32));
        let shr_u32 = library.import(Box::new(ShrU32));
        let shl_u32 = library.import(Box::new(ShlU32));

        triton_asm! {
            // BEFORE: _ arg_3 arg_2 arg_1 arg_0 shamt
            // AFTER:  _ res_3 res_2 res_1 res_0
            {entrypoint}:
                // check preconditions on copies of the argument limbs
                dup 4 dup 4 dup 4 dup 4
                // _ arg_3 arg_2 arg_1 arg_0 shamt arg_3 arg_2 arg_1 arg_0

                push 128 dup 5
                lt
                // _ [arg] shamt arg_3 arg_2 arg_1 arg_0 (shamt < 128)

                // the bound enforced here is `shamt < 128`; it is reported
                // under SHAMT_NOT_U32_ERROR_ID
                assert error_id {Self::SHAMT_NOT_U32_ERROR_ID}
                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_0_NOT_U32_ERROR_ID}
                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_1_NOT_U32_ERROR_ID}
                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_2_NOT_U32_ERROR_ID}
                call {is_u32} assert error_id {Self::ARGUMENT_LIMB_3_NOT_U32_ERROR_ID}
                // _ arg_3 arg_2 arg_1 arg_0 shamt

                // extract the sign bit of the most significant limb
                dup 4 push 31 call {shr_u32}
                hint msb = stack[0]
                // _ arg_3 arg_2 arg_1 arg_0 shamt msb

                call {shr_i128_by_32n}
                // _ a_3 a_2 a_1 a_0 s msb
                // where s = shamt % 32 and [a] = arg >> (shamt - s)

                dup 1 push 0 eq dup 0
                // _ [a] s msb (s == 0) (s == 0)

                skiz call {clean_up_for_early_return}
                skiz return
                // _ a_3 a_2 a_1 a_0 s msb, with 0 < s < 32

                push 32 dup 2 push -1 mul add
                // _ a_3 a_2 a_1 a_0 s msb (32-s)

                push {u32::MAX} dup 2 mul
                // _ a_3 a_2 a_1 a_0 s msb (32-s) fill

                dup 1 call {shl_u32}
                // _ a_3 a_2 a_1 a_0 s msb (32-s) (fill << (32-s))

                // res_3 = (fill << (32-s)) + (a_3 >> s)
                pick 7 dup 0
                // _ a_2 a_1 a_0 s msb (32-s) (fill << (32-s)) a_3 a_3
                dup 3 call {shl_u32}
                // _ a_2 a_1 a_0 s msb (32-s) (fill << (32-s)) a_3 (a_3 << (32-s))
                place 2
                // _ a_2 a_1 a_0 s msb (32-s) (a_3 << (32-s)) (fill << (32-s)) a_3
                dup 5 call {shr_u32}
                // _ a_2 a_1 a_0 s msb (32-s) (a_3 << (32-s)) (fill << (32-s)) (a_3 >> s)
                add
                // _ a_2 a_1 a_0 s msb (32-s) (a_3 << (32-s)) res_3

                // res_2 = (a_3 << (32-s)) + (a_2 >> s)
                swap 7 dup 0
                // _ res_3 a_1 a_0 s msb (32-s) (a_3 << (32-s)) a_2 a_2
                dup 3 call {shl_u32}
                place 2
                // _ res_3 a_1 a_0 s msb (32-s) (a_2 << (32-s)) (a_3 << (32-s)) a_2
                dup 5 call {shr_u32}
                add
                // _ res_3 a_1 a_0 s msb (32-s) (a_2 << (32-s)) res_2

                // res_1 = (a_2 << (32-s)) + (a_1 >> s)
                swap 6 dup 0
                // _ res_3 res_2 a_0 s msb (32-s) (a_2 << (32-s)) a_1 a_1
                dup 3 call {shl_u32}
                place 2
                // _ res_3 res_2 a_0 s msb (32-s) (a_1 << (32-s)) (a_2 << (32-s)) a_1
                dup 5 call {shr_u32}
                add
                // _ res_3 res_2 a_0 s msb (32-s) (a_1 << (32-s)) res_1

                // res_0 = (a_1 << (32-s)) + (a_0 >> s)
                swap 5
                // _ res_3 res_2 res_1 s msb (32-s) (a_1 << (32-s)) a_0
                pick 4
                // _ res_3 res_2 res_1 msb (32-s) (a_1 << (32-s)) a_0 s
                call {shr_u32}
                add
                // _ res_3 res_2 res_1 msb (32-s) res_0

                place 2 pop 2
                // _ res_3 res_2 res_1 res_0
                return

            // Shifts right by whole limbs until fewer than 32 bits remain.
            // BEFORE: _ x_3 x_2 x_1 x_0 shamt msb
            // AFTER:  _ [x >> (shamt - shamt % 32)] (shamt % 32) msb
            {shr_i128_by_32n}:
                push 32 dup 2 lt
                // _ [x] shamt msb (shamt < 32)
                skiz return
                push {u32::MAX} dup 1 mul
                // _ x_3 x_2 x_1 x_0 shamt msb fill
                place 6
                // _ fill x_3 x_2 x_1 x_0 shamt msb
                pick 2 pop 1
                // _ fill x_3 x_2 x_1 shamt msb
                pick 1 addi -32 place 1
                // _ fill x_3 x_2 x_1 (shamt - 32) msb
                recurse

            // Drops the shift amount and the sign bit, keeping the flag for the
            // caller's `skiz return`.
            // BEFORE: _ [x] s msb is_zero
            // AFTER:  _ [x] is_zero
            {clean_up_for_early_return}:
                place 2
                pop 2
                return
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::tasm_final_state;
    use crate::test_prelude::*;

    impl ShiftRight {
        /// Runs the snippet on `(arg, shamt)` and checks that the final stack
        /// matches the one produced by the Rust shadow.
        fn assert_expected_shift_behavior(&self, arg: i128, shamt: u32) {
            let initial_stack = self.set_up_test_stack((arg, shamt));
            let mut expected_stack = initial_stack.clone();
            self.rust_shadow(&mut expected_stack);
            test_rust_equivalence_given_complete_state(
                &ShadowedClosure::new(Self),
                &initial_stack,
                &[],
                &NonDeterminism::default(),
                &None,
                Some(&expected_stack),
            );
        }
    }

    impl Closure for ShiftRight {
        type Args = (i128, u32);

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let (arg, shift_amount) = pop_encodable::<Self::Args>(stack);
            push_encodable(stack, &(arg >> shift_amount));
        }

        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            (rng.random(), rng.random_range(0..128))
        }
    }

    #[test]
    fn standard_test() {
        ShadowedClosure::new(ShiftRight).test()
    }

    #[proptest]
    fn proptest(#[strategy(arb())] arg: i128, #[strategy(0u32..128)] shamt: u32) {
        ShiftRight.assert_expected_shift_behavior(arg, shamt);
    }

    #[test]
    fn test_edge_cases() {
        let arguments = (0..4)
            .map(|_| [-1, 0, 1])
            .multi_cartesian_product()
            .map(|limbs| <[i128; 4]>::try_from(limbs).unwrap())
            .map(|[l0, l1, l2, l3]| l0 + (l1 << 32) + (l2 << 64) + (l3 << 96));
        let shift_amounts = [0, 1, 16, 31]
            .into_iter()
            .cartesian_product(0..4)
            .map(|(l, r)| l + 32 * r);
        arguments
            .cartesian_product(shift_amounts)
            .for_each(|(arg, shamt)| ShiftRight.assert_expected_shift_behavior(arg, shamt));
    }

    /// Runs the compiled snippet on `(arg, 127)` and returns the number of set
    /// bits in the resulting `i128`.
    fn num_set_bits_after_shifting_right_by_127(arg: i128) -> u32 {
        let mut final_state = tasm_final_state(
            &ShadowedClosure::new(ShiftRight),
            &ShiftRight.set_up_test_stack((arg, 127)),
            &[],
            NonDeterminism::default(),
            &None,
        );
        let final_stack = &mut final_state.op_stack.stack;
        pop_encodable::<i128>(final_stack).count_ones()
    }

    /// Shifting right by 127 keeps only the sign bit. The result is therefore
    /// `-1` (all 128 bits set) for negative arguments and `0` for everything
    /// else, including zero itself.
    fn expected_set_bits_after_shifting_right_by_127(arg: i128) -> u32 {
        // Branch on the sign, not on `is_positive`: zero is not positive, yet
        // `0 >> 127 == 0` has no bits set.
        if arg.is_negative() {
            i128::BITS
        } else {
            0
        }
    }

    #[proptest(cases = 50)]
    fn shifting_right_by_127_is_zero_or_minus_1(arg: i128) {
        prop_assert_eq!(
            expected_set_bits_after_shifting_right_by_127(arg),
            num_set_bits_after_shifting_right_by_127(arg)
        );
    }

    /// Pins the boundary values that random sampling essentially never hits.
    #[test]
    fn shifting_boundary_values_right_by_127_is_zero_or_minus_1() {
        for arg in [0, 1, -1, i128::MIN, i128::MAX] {
            assert_eq!(
                expected_set_bits_after_shifting_right_by_127(arg),
                num_set_bits_after_shifting_right_by_127(arg),
                "unexpected result for {arg} >> 127"
            );
        }
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks the i128 right-shift snippet against its Rust shadow.
    #[test]
    fn benchmark() {
        let shadowed_snippet = ShadowedClosure::new(ShiftRight);
        shadowed_snippet.bench();
    }
}