use num::BigUint;
use sp1_curves::{params::NumWords, weierstrass::FpOpField};
use sp1_jit::SyscallContext;
use sp1_primitives::consts::u64_to_u32;
use typenum::Unsigned;
use crate::events::FieldOperation;
/// Executes the Fp2 addition / subtraction precompile, writing the result over `x`.
///
/// `arg1` (`x`) and `arg2` (`y`) are guest addresses of Fp2 elements. Each element occupies
/// `WordsCurvePoint` 64-bit words. The first half holds the `c0` coefficient and the second half
/// holds `c1`, both as little-endian limbs. The coefficient-wise sum or difference modulo
/// `P::MODULUS` is stored back at `x`.
///
/// Always returns `None`.
///
/// # Panics
///
/// - If either pointer is not 8-byte aligned.
/// - If `fp_op` is neither [`FieldOperation::Add`] nor [`FieldOperation::Sub`].
/// - For `Sub`, if a coefficient of `y` is greater than the matching coefficient of `x` plus the
///   modulus, because the `BigUint` subtraction underflows. NOTE(review): presumably guests only
///   pass reduced operands; confirm.
///
/// # Safety
///
/// NOTE(review): the body only calls `SyscallContext` memory accessors. The caller presumably
/// must uphold whatever contract those accessors require (e.g. valid guest addresses). Confirm
/// against `SyscallContext`.
pub(crate) unsafe fn fp2_addsub_syscall<P: FpOpField>(
    ctx: &mut impl SyscallContext,
    arg1: u64,
    arg2: u64,
    fp_op: FieldOperation,
) -> Option<u64> {
    let (x_ptr, y_ptr) = (arg1, arg2);
    assert!(x_ptr.is_multiple_of(8), "x_ptr must be 8-byte aligned");
    assert!(y_ptr.is_multiple_of(8), "y_ptr must be 8-byte aligned");

    let total_words = <P as NumWords>::WordsCurvePoint::USIZE;
    let half_words = total_words / 2;

    // `x` is fetched with `mr_slice_unsafe` and `y` with `mr_slice`. NOTE(review): presumably the
    // distinction is because `x` is overwritten by `mw_slice` below; confirm.
    let x_limbs = u64_to_u32(ctx.mr_slice_unsafe(x_ptr, total_words));
    let y_limbs = u64_to_u32(ctx.mr_slice(y_ptr, total_words));

    // Split a flat limb buffer into its (c0, c1) coefficients.
    let to_coeffs = |limbs: &[u32]| {
        let (lo, hi) = limbs.split_at(limbs.len() / 2);
        (BigUint::from_slice(lo), BigUint::from_slice(hi))
    };
    let (a0, a1) = to_coeffs(&x_limbs[..]);
    let (b0, b1) = to_coeffs(&y_limbs[..]);
    let modulus = BigUint::from_bytes_le(P::MODULUS);

    let (c0, c1) = match fp_op {
        FieldOperation::Add => (
            (&a0 + &b0) % &modulus,
            (&a1 + &b1) % &modulus,
        ),
        // Adding the modulus first keeps the unsigned BigUint difference non-negative when
        // b <= a + modulus.
        FieldOperation::Sub => (
            (&a0 + &modulus - &b0) % &modulus,
            (&a1 + &modulus - &b1) % &modulus,
        ),
        _ => panic!("Invalid operation"),
    };

    // Pad c0 to exactly `half_words` limbs, then append c1 and pad the whole buffer to
    // `total_words`, so the output mirrors the input layout.
    let mut out = c0.to_u64_digits();
    out.resize(half_words, 0);
    out.extend(c1.to_u64_digits());
    out.resize(total_words, 0);

    ctx.bump_memory_clk();
    ctx.mw_slice(x_ptr, &out);
    None
}