use num::BigUint;
use crate::{events::Uint256Operation, SyscallCode};
use sp1_jit::{
RiscRegister::{X12, X13, X14, X5},
SyscallContext,
};
/// Number of 64-bit words that make up one 256-bit operand (256 / 64 = 4).
const U256_NUM_WORDS: usize = 256 / 64;
pub unsafe fn uint256_ops(ctx: &mut impl SyscallContext, arg1: u64, arg2: u64) -> Option<u64> {
let syscall_id = ctx.rr(X5);
let syscall_code = SyscallCode::from_u32(syscall_id as u32);
let op = syscall_code.uint256_op_map();
let a_ptr = arg1;
let b_ptr = arg2;
let c_ptr = ctx.rr(X12);
let d_ptr = ctx.rr(X13);
let e_ptr = ctx.rr(X14);
let uint256_a = {
let a = ctx.mr_slice(a_ptr, U256_NUM_WORDS);
BigUint::from_slice(
&a.into_iter().flat_map(|&x| [x as u32, (x >> 32) as u32]).collect::<Vec<_>>(),
)
};
ctx.bump_memory_clk();
let uint256_b = {
let b = ctx.mr_slice(b_ptr, U256_NUM_WORDS);
BigUint::from_slice(
&b.into_iter().flat_map(|&x| [x as u32, (x >> 32) as u32]).collect::<Vec<_>>(),
)
};
ctx.bump_memory_clk();
let uint256_c = {
let c = ctx.mr_slice(c_ptr, U256_NUM_WORDS);
BigUint::from_slice(
&c.into_iter().flat_map(|&x| [x as u32, (x >> 32) as u32]).collect::<Vec<_>>(),
)
};
let intermediate_result = match op {
Uint256Operation::Add => uint256_a + uint256_b + uint256_c,
Uint256Operation::Mul => uint256_a * uint256_b + uint256_c,
};
let mut u64_result = intermediate_result.to_u64_digits();
u64_result.resize(8, 0);
ctx.bump_memory_clk();
ctx.mw_slice(d_ptr, &u64_result[0..4]);
ctx.bump_memory_clk();
ctx.mw_slice(e_ptr, &u64_result[4..8]);
None
}