use alloy::primitives::{I256, U256, U512};
use tycho_common::simulation::errors::SimulationError;
/// Multiplies `a` by `b`, reporting overflow as a `FatalError` instead of wrapping.
pub fn safe_mul_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
    _construc_result_u256(a.checked_mul(b))
}
/// Truncating division `a / b`.
///
/// # Errors
/// Returns `FatalError("Division by zero")` when `b` is zero. Any `None` from
/// `checked_div` is reported through `_construc_result_u256`.
pub fn safe_div_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
    if b.is_zero() {
        Err(SimulationError::FatalError("Division by zero".to_string()))
    } else {
        _construc_result_u256(a.checked_div(b))
    }
}
/// Adds `a` and `b`, reporting overflow as a `FatalError` instead of wrapping.
pub fn safe_add_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
    _construc_result_u256(a.checked_add(b))
}
/// Computes `a - b`, reporting underflow (`b > a`) as a `FatalError` instead of wrapping.
pub fn safe_sub_u256(a: U256, b: U256) -> Result<U256, SimulationError> {
    _construc_result_u256(a.checked_sub(b))
}
/// Returns `(a / b, a % b)`: the truncated quotient and the remainder.
///
/// # Errors
/// Returns `FatalError("Division by zero")` when `b` is zero.
pub fn div_mod_u256(a: U256, b: U256) -> Result<(U256, U256), SimulationError> {
    if b.is_zero() {
        Err(SimulationError::FatalError("Division by zero".to_string()))
    } else {
        Ok((a / b, a % b))
    }
}
/// Turns the `Option` produced by a `U256` `checked_*` operation into a `Result`.
///
/// `Some(v)` becomes `Ok(v)`. Any `None` becomes
/// `FatalError("U256 arithmetic overflow")`.
pub fn _construc_result_u256(res: Option<U256>) -> Result<U256, SimulationError> {
    res.ok_or_else(|| SimulationError::FatalError("U256 arithmetic overflow".to_string()))
}
/// Multiplies `a` by `b`, reporting overflow as a `FatalError` instead of wrapping.
pub fn safe_mul_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
    _construc_result_u512(a.checked_mul(b))
}
/// Truncating division `a / b`.
///
/// # Errors
/// Returns `FatalError("Division by zero")` when `b` is zero. Any `None` from
/// `checked_div` is reported through `_construc_result_u512`.
pub fn safe_div_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
    if b.is_zero() {
        Err(SimulationError::FatalError("Division by zero".to_string()))
    } else {
        _construc_result_u512(a.checked_div(b))
    }
}
/// Adds `a` and `b`, reporting overflow as a `FatalError` instead of wrapping.
pub fn safe_add_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
    _construc_result_u512(a.checked_add(b))
}
/// Computes `a - b`, reporting underflow (`b > a`) as a `FatalError` instead of wrapping.
pub fn safe_sub_u512(a: U512, b: U512) -> Result<U512, SimulationError> {
    _construc_result_u512(a.checked_sub(b))
}
/// Returns `(a / b, a % b)`: the truncated quotient and the remainder.
///
/// # Errors
/// Returns `FatalError("Division by zero")` when `b` is zero.
pub fn div_mod_u512(a: U512, b: U512) -> Result<(U512, U512), SimulationError> {
    if b.is_zero() {
        Err(SimulationError::FatalError("Division by zero".to_string()))
    } else {
        Ok((a / b, a % b))
    }
}
/// Turns the `Option` produced by a `U512` `checked_*` operation into a `Result`.
///
/// `Some(v)` becomes `Ok(v)`. Any `None` becomes
/// `FatalError("U512 arithmetic overflow")`.
pub fn _construc_result_u512(res: Option<U512>) -> Result<U512, SimulationError> {
    res.ok_or_else(|| SimulationError::FatalError("U512 arithmetic overflow".to_string()))
}
/// Multiplies `a` by `b`, reporting signed overflow as a `FatalError` instead of wrapping.
pub fn safe_mul_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
    _construc_result_i256(a.checked_mul(b))
}
/// Signed division `a / b`.
///
/// # Errors
/// Returns `FatalError("Division by zero")` when `b` is zero. Any other `None`
/// from `checked_div` is reported through `_construc_result_i256` as an overflow.
pub fn safe_div_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
    if b.is_zero() {
        Err(SimulationError::FatalError("Division by zero".to_string()))
    } else {
        _construc_result_i256(a.checked_div(b))
    }
}
/// Adds `a` and `b`, reporting signed overflow as a `FatalError` instead of wrapping.
pub fn safe_add_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
    _construc_result_i256(a.checked_add(b))
}
/// Computes `a - b`, reporting signed overflow as a `FatalError` instead of wrapping.
pub fn safe_sub_i256(a: I256, b: I256) -> Result<I256, SimulationError> {
    _construc_result_i256(a.checked_sub(b))
}
/// Turns the `Option` produced by an `I256` `checked_*` operation into a `Result`.
///
/// `Some(v)` becomes `Ok(v)`. Any `None` becomes
/// `FatalError("I256 arithmetic overflow")`.
pub fn _construc_result_i256(res: Option<I256>) -> Result<I256, SimulationError> {
    res.ok_or_else(|| SimulationError::FatalError("I256 arithmetic overflow".to_string()))
}
/// Returns the integer square root `⌊√value⌋` of a `U512` using Newton's method.
///
/// The first guess is `2^(bit_len / 2)`. If an early step moves the guess upward,
/// the move is capped at doubling the current guess. Once the sequence starts
/// descending, it stops either when it reaches a fixed point or when it would
/// turn upward again. In both cases it returns the last (smaller) guess.
pub fn sqrt_u512(value: U512) -> U512 {
    let one = U512::from(1u32);
    let two = U512::from(2u32);

    // 0 and 1 are their own square roots.
    if value == U512::ZERO || value == one {
        return value;
    }

    let bit_len = 512 - value.leading_zeros();
    let mut guess = one << (bit_len / 2);
    let mut descending = false;

    loop {
        let next = (value / guess + guess) / two;

        if next == guess {
            return guess;
        }

        if next < guess {
            descending = true;
            guess = next;
            continue;
        }

        // `next > guess`: once descending, this is the floor/floor+1 oscillation.
        if descending {
            return guess;
        }
        guess = next.min(guess * two);
    }
}
/// Returns the integer square root `⌊√value⌋` of a `U256`.
///
/// Delegates to the recursive Karatsuba square root in `compute_karatsuba_sqrt`,
/// passing the exact bit length of `value` as that routine requires.
///
/// # Errors
/// This function currently never returns `Err`; the `Result` return type is kept
/// for API compatibility.
pub fn sqrt_u256(value: U256) -> Result<U256, SimulationError> {
    if value.is_zero() {
        return Ok(U256::ZERO);
    }
    let bits = 256 - value.leading_zeros();
    // The recursion needs a remainder slot and a scratch slot; neither is used here.
    let mut remainder = U256::ZERO;
    let mut scratch = U256::ZERO;
    Ok(compute_karatsuba_sqrt(value, &mut remainder, &mut scratch, bits))
}
/// Recursive Karatsuba square root on a `U256`, following the structure of
/// Zimmermann's `SqrtRem`.
///
/// Returns `s = ⌊√x⌋` and writes the remainder `x - s²` into `*r`. `t` is a scratch
/// slot reused between steps. `bits` must be the exact bit length of `x`. The
/// recursion splits `x` at `b = bits / 4`:
/// - the high part `x >> 2b` is solved recursively with `bits - 2b` bits;
/// - the two low `b`-bit quarters are then folded in with one division by `2s'`;
/// - at most one downward correction follows.
fn compute_karatsuba_sqrt(x: U256, r: &mut U256, t: &mut U256, bits: usize) -> U256 {
    if bits <= 64 {
        // Base case: `x` fits entirely in the lowest limb.
        let x_small = x.as_limbs()[0];
        let s = isqrt_u64(x_small);
        // `s * s <= x_small` is guaranteed by `isqrt_u64`, so neither op can overflow.
        *r = U256::from(x_small - s * s);
        return U256::from(s);
    }
    let b = bits / 4;
    // Root and remainder of the high part (x >> 2b).
    let mut q = x >> (b * 2);
    let mut s = compute_karatsuba_sqrt(q, r, t, bits - b * 2);
    // Bring down the next b-bit quarter: r = r' * 2^b + a1.
    *t = (U256::from(1u32) << (b * 2)) - U256::from(1u32);
    *r = (*r << b) | ((x & *t) >> b);
    // Divide by 2s' to get the next b bits of the root (q) and the partial remainder.
    s <<= 1;
    q = *r / s;
    *r -= q * s;
    // s = s' * 2^b + q
    s = (s << (b - 1)) + q;
    // Bring down the lowest b-bit quarter: r = u * 2^b + a0.
    *t = (U256::from(1u32) << b) - U256::from(1u32);
    *r = (*r << b) | (x & *t);
    // The true remainder is r - q². If that is negative, s is one too large:
    // (s - 1)² = s² - 2s + 1, so add 2s - 1 to the remainder and decrement s.
    let q_squared = q * q;
    if *r < q_squared {
        *t = (s << 1) - U256::from(1u32);
        *r += *t;
        s -= U256::from(1u32);
    }
    *r -= q_squared;
    s
}

/// Exact integer square root `⌊√n⌋` of a `u64`.
///
/// `f64` has only a 53-bit significand, so `(n as f64).sqrt()` can be one too
/// large for `n > 2^52`. For example, `u64::MAX` rounds to `2^64`, which yields
/// `2^32`. The float estimate is therefore corrected with exact `u128`
/// arithmetic; `(s + 1)²` stays below `2^65`, so it cannot overflow.
fn isqrt_u64(n: u64) -> u64 {
    let wide = u128::from(n);
    let mut s = u128::from((n as f64).sqrt() as u64);
    while s * s > wide {
        s -= 1;
    }
    while (s + 1) * (s + 1) <= wide {
        s += 1;
    }
    // s <= 2^32 - 1 here because n < 2^64, so the narrowing is lossless.
    s as u64
}
#[cfg(test)]
mod safe_math_tests {
    use std::str::FromStr;

    use rstest::rstest;

    use super::*;

    const U256_MAX: U256 = U256::from_limbs([u64::MAX, u64::MAX, u64::MAX, u64::MAX]);
    const U512_MAX: U512 = U512::from_limbs([
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
        u64::MAX,
    ]);
    const I256_MAX: I256 = I256::from_raw(U256::from_limbs([
        u64::MAX,
        u64::MAX,
        u64::MAX,
        9223372036854775807u64, // 0x7FFF_FFFF_FFFF_FFFF: sign bit clear
    ]));
    const I256_MIN: I256 = I256::from_raw(U256::from_limbs([
        0,
        0,
        0,
        9223372036854775808u64, // 0x8000_0000_0000_0000: only the sign bit set
    ]));

    fn u256(s: &str) -> U256 {
        U256::from_str(s).unwrap()
    }

    #[rstest]
    #[case(U256_MAX, u256("2"), true, false, u256("0"))]
    #[case(u256("3"), u256("2"), false, true, u256("6"))]
    fn test_safe_mul_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_mul_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(U256_MAX, u256("2"), true, false, u256("0"))]
    #[case(u256("3"), u256("2"), false, true, u256("5"))]
    fn test_safe_add_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_add_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u256("0"), u256("2"), true, false, u256("0"))]
    #[case(u256("10"), u256("2"), false, true, u256("8"))]
    fn test_safe_sub_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_sub_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u256("1"), u256("0"), true, false, u256("0"))]
    #[case(u256("10"), u256("2"), false, true, u256("5"))]
    fn test_safe_div_u256(
        #[case] a: U256,
        #[case] b: U256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U256,
    ) {
        let res = safe_div_u256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    fn u512(s: &str) -> U512 {
        U512::from_str(s).unwrap()
    }

    #[rstest]
    #[case(U512_MAX, u512("2"), true, false, u512("0"))]
    #[case(u512("3"), u512("2"), false, true, u512("6"))]
    fn test_safe_mul_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_mul_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(U512_MAX, u512("2"), true, false, u512("0"))]
    #[case(u512("3"), u512("2"), false, true, u512("5"))]
    fn test_safe_add_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_add_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u512("0"), u512("2"), true, false, u512("0"))]
    #[case(u512("10"), u512("2"), false, true, u512("8"))]
    fn test_safe_sub_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_sub_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(u512("1"), u512("0"), true, false, u512("0"))]
    #[case(u512("10"), u512("2"), false, true, u512("5"))]
    fn test_safe_div_u512(
        #[case] a: U512,
        #[case] b: U512,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: U512,
    ) {
        let res = safe_div_u512(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    fn i256(s: &str) -> I256 {
        I256::from_str(s).unwrap()
    }

    #[rstest]
    #[case(I256_MAX, i256("2"), true, false, i256("0"))]
    #[case(i256("3"), i256("2"), false, true, i256("6"))]
    fn test_safe_mul_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_mul_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(I256_MAX, i256("2"), true, false, i256("0"))]
    #[case(i256("3"), i256("2"), false, true, i256("5"))]
    fn test_safe_add_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_add_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(I256_MIN, i256("2"), true, false, i256("0"))]
    #[case(i256("10"), i256("2"), false, true, i256("8"))]
    fn test_safe_sub_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_sub_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[rstest]
    #[case(i256("1"), i256("0"), true, false, i256("0"))]
    #[case(i256("10"), i256("2"), false, true, i256("5"))]
    fn test_safe_div_i256(
        #[case] a: I256,
        #[case] b: I256,
        #[case] is_err: bool,
        #[case] is_ok: bool,
        #[case] expected: I256,
    ) {
        let res = safe_div_i256(a, b);
        assert_eq!(res.is_err(), is_err);
        assert_eq!(res.is_ok(), is_ok);
        if is_ok {
            assert_eq!(res.unwrap(), expected);
        }
    }

    #[test]
    fn test_sqrt_u512() {
        // Perfect squares.
        assert_eq!(sqrt_u512(U512::ZERO), U512::ZERO);
        assert_eq!(sqrt_u512(U512::from(1u32)), U512::from(1u32));
        assert_eq!(sqrt_u512(U512::from(4u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(100u32)), U512::from(10u32));
        assert_eq!(sqrt_u512(U512::from(10000u32)), U512::from(100u32));
        assert_eq!(sqrt_u512(U512::from(1000000u32)), U512::from(1000u32));

        // Non-squares round down.
        assert_eq!(sqrt_u512(U512::from(2u32)), U512::from(1u32));
        assert_eq!(sqrt_u512(U512::from(3u32)), U512::from(1u32));
        assert_eq!(sqrt_u512(U512::from(5u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(8u32)), U512::from(2u32));
        assert_eq!(sqrt_u512(U512::from(10u32)), U512::from(3u32));
        assert_eq!(sqrt_u512(U512::from(15u32)), U512::from(3u32));
        assert_eq!(sqrt_u512(U512::from(99u32)), U512::from(9u32));

        let large = U512::from_str("1000000000000000000000000000000000000").unwrap();
        let sqrt_large = sqrt_u512(large);
        assert!(sqrt_large * sqrt_large <= large);
        assert!((sqrt_large + U512::from(1u32)) * (sqrt_large + U512::from(1u32)) > large);

        // ⌊√(2^512 - 1)⌋ = 2^256 - 1
        assert_eq!(sqrt_u512(U512_MAX), (U512::from(1u32) << 256usize) - U512::from(1u32));
    }

    /// Asserts that `s == ⌊√v⌋`, i.e. `s² <= v < (s + 1)²`, without overflowing.
    fn assert_is_isqrt_u256(v: U256, s: U256) {
        let sq = s.checked_mul(s).expect("root squared must fit in U256");
        assert!(sq <= v, "sqrt({v}) = {s} is too large");
        let next = s + U256::from(1u32);
        // If (s + 1)² overflows it is certainly greater than v.
        if let Some(next_sq) = next.checked_mul(next) {
            assert!(next_sq > v, "sqrt({v}) = {s} is too small");
        }
    }

    #[rstest]
    #[case(u256("0"), u256("0"))]
    #[case(u256("1"), u256("1"))]
    #[case(u256("4"), u256("2"))]
    #[case(u256("15"), u256("3"))]
    #[case(u256("18446744073709551615"), u256("4294967295"))] // u64::MAX
    #[case(u256("18446744073709551616"), u256("4294967296"))] // 2^64
    #[case(u256("1000000000000000000000000000000000000"), u256("1000000000000000000"))]
    #[case(U256_MAX, u256("340282366920938463463374607431768211455"))] // 2^128 - 1
    fn test_sqrt_u256(#[case] value: U256, #[case] expected: U256) {
        assert_eq!(sqrt_u256(value).unwrap(), expected);
    }

    #[test]
    fn test_sqrt_u256_is_floor_root() {
        let one = U256::from(1u32);
        let mut values = vec![U256_MAX, U256_MAX - one];
        // Around powers of two, including where f64 loses precision (> 2^52).
        for shift in [52usize, 53, 60, 63, 64, 65, 100, 127, 128, 129, 200, 255] {
            let p = one << shift;
            values.extend([p - one, p, p + one]);
        }
        // Around perfect squares, where an off-by-one root is most likely.
        for k in [u64::MAX, (1u64 << 32) - 1, 1u64 << 32, 3_037_000_499] {
            let k = U256::from(k);
            let sq = k * k;
            values.extend([sq - one, sq, sq + one]);
        }
        for v in values {
            assert_is_isqrt_u256(v, sqrt_u256(v).unwrap());
        }
    }
}