use triton_vm::prelude::*;
use crate::arithmetic;
use crate::arithmetic::u64::mul_two_u64s_to_u128::MulTwoU64sToU128;
use crate::prelude::*;
/// Checked multiplication of two `u160`s, each represented as five `u32` limbs.
///
/// The generated code crashes the VM (with one of the `OVERFLOW_*` error ids or
/// an error from an imported helper snippet) if the product does not fit into a
/// `u160`.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Hash)]
pub struct SafeMul;
impl SafeMul {
    /// Crash code: the top limbs of both operands (`left[4]` and `right[4]`)
    /// are non-zero, so their product lands at weight `2^256`.
    pub(crate) const OVERFLOW_0: i128 = 580;

    /// Crash code: `right[4]` and one of `left[2..=3]` are non-zero, producing
    /// a term of weight `2^192`.
    pub(crate) const OVERFLOW_1: i128 = 581;

    /// Crash code: `left[4]` and one of `right[2..=3]` are non-zero, producing
    /// a term of weight `2^192`.
    pub(crate) const OVERFLOW_2: i128 = 582;

    /// Crash code: the sum of the partial products of weight `2^128` does not
    /// fit into 32 bits.
    pub(crate) const OVERFLOW_3: i128 = 583;

    /// Crash code: the sum of the partial products of weight `2^64` does not
    /// fit into 96 bits.
    pub(crate) const OVERFLOW_4: i128 = 584;
}
impl BasicSnippet for SafeMul {
    /// Two `u160` operands: `right` is listed first (deeper on the stack) and
    /// `left` second (on top of the stack).
    fn parameters(&self) -> Vec<(DataType, String)> {
        ["right", "left"]
            .map(|side| (DataType::U160, side.to_string()))
            .to_vec()
    }

    /// A single `u160`: the product `left * right`.
    fn return_values(&self) -> Vec<(DataType, String)> {
        vec![(DataType::U160, "product".to_string())]
    }

    fn entrypoint(&self) -> String {
        "tasmlib_arithmetic_u160_safe_mul".to_string()
    }

    /// Emits the checked `u160` multiplication.
    ///
    /// Both operands are zero-extended by one word and read as three u64s,
    /// `X = X2*2^128 + X1*2^64 + X0` with `X2 = (0, x4)`, `X1 = (x3, x2)`,
    /// `X0 = (x1, x0)`. The product is `sum(Li*Rj * 2^(64*(i+j)))`. The code
    /// handles the terms by weight:
    /// - terms with `i + j >= 3` must vanish (`OVERFLOW_0`..`OVERFLOW_2`);
    /// - `S = L0*R2 + L1*R1 + L2*R0` must be below `2^32` (`OVERFLOW_3`);
    /// - `Q = L0*R1 + L1*R0` must be below `2^96` (`OVERFLOW_4`);
    /// - the result is `S*2^128 + Q*2^64 + L0*R0`, summed with u160 safe adds.
    fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
        // Full-width u64 x u64 -> u128 product (used where the result may exceed 64 bits).
        let u64_to_u128_mul = library.import(Box::new(MulTwoU64sToU128));
        // Helper snippets used for the checked partial products and sums; overflow
        // inside them presumably crashes with their own error ids.
        let u64_safe_mul = library.import(Box::new(arithmetic::u64::safe_mul::SafeMul));
        let u64_safe_add = library.import(Box::new(arithmetic::u64::add::Add));
        let u128_safe_add = library.import(Box::new(arithmetic::u128::safe_add::SafeAdd));
        let u160_safe_add = library.import(Box::new(arithmetic::u160::safe_add::SafeAdd));
        triton_asm!(
            // BEFORE: _ r4 r3 r2 r1 r0 l4 l3 l2 l1 l0   (limb 0 least significant, on top)
            // AFTER:  _ [product: u160]
            {self.entrypoint()}:
            // Zero-extend both operands by one word so each splits into three u64s.
            push 0
            place 10
            // _ 0 r4 r3 r2 r1 r0 l4 l3 l2 l1 l0
            push 0
            place 5
            // _ 0 r4 r3 r2 r1 r0 0 l4 l3 l2 l1 l0

            // a := (r3 == 0) * (r2 == 0), i.e. R1 == 0
            dup 9
            push 0
            eq
            dup 9
            push 0
            eq
            mul
            // b := (l4 == 0), i.e. L2 == 0
            dup 5
            push 0
            eq
            // c := (r4 == 0), i.e. R2 == 0
            dup 12
            push 0
            eq
            // d := (l3 == 0) * (l2 == 0), i.e. L1 == 0
            dup 6
            push 0
            eq
            dup 6
            push 0
            eq
            mul
            // _ 0 r4 r3 r2 r1 r0 0 l4 l3 l2 l1 l0 a b c d

            // For bits x, y: pop_count(x + y) is 0 for x + y == 0 and 1 for
            // x + y in {1, 2}, so asserting it equals 1 asserts "x or y".

            // L2*R2 (weight 2^256) must vanish: L2 == 0 or R2 == 0.
            dup 2
            dup 2
            add
            pop_count
            assert error_id {Self::OVERFLOW_0}
            // _ 0 r4 r3 r2 r1 r0 0 l4 l3 l2 l1 l0 a b c d

            // L1*R2 (weight 2^192) must vanish: R2 == 0 or L1 == 0.
            add
            pop_count
            assert error_id {Self::OVERFLOW_1}
            // _ 0 r4 r3 r2 r1 r0 0 l4 l3 l2 l1 l0 a b

            // L2*R1 (weight 2^192) must vanish: R1 == 0 or L2 == 0.
            add
            pop_count
            assert error_id {Self::OVERFLOW_2}
            // _ 0 r4 r3 r2 r1 r0 0 l4 l3 l2 l1 l0

            // S := L0*R2 + L1*R1 + L2*R0, the coefficient of 2^128, computed with
            // the u64 safe_mul / add helpers.
            pick 11
            pick 11
            // _ r3 r2 r1 r0 0 l4 l3 l2 l1 l0 [R2]
            dup 3
            dup 3
            call {u64_safe_mul}
            // _ r3 r2 r1 r0 0 l4 l3 l2 l1 l0 [L0*R2]
            dup 11
            dup 11
            dup 7
            dup 7
            call {u64_safe_mul}
            // _ r3 r2 r1 r0 0 l4 l3 l2 l1 l0 [L0*R2] [R1*L1]
            dup 11
            dup 11
            pick 11
            pick 11
            call {u64_safe_mul}
            // _ r3 r2 r1 r0 l3 l2 l1 l0 [L0*R2] [R1*L1] [R0*L2]
            call {u64_safe_add}
            call {u64_safe_add}
            // _ r3 r2 r1 r0 l3 l2 l1 l0 s_hi s_lo

            // S*2^128 only fits into a u160 if S < 2^32, i.e. s_hi == 0.
            pick 1
            push 0
            eq
            assert error_id {Self::OVERFLOW_3}
            // _ r3 r2 r1 r0 l3 l2 l1 l0 s_lo

            // Lay out S*2^128 as the u160 (s_lo, 0, 0, 0, 0).
            push 0
            push 0
            push 0
            push 0
            // _ r3 r2 r1 r0 l3 l2 l1 l0 [S*2^128]

            // Q := L0*R1 + L1*R0, the coefficient of 2^64, as a u128.
            pick 12
            pick 12
            dup 8
            dup 8
            call {u64_to_u128_mul}
            // _ r1 r0 l3 l2 l1 l0 [S*2^128] [R1*L0: u128]
            dup 14
            dup 14
            pick 14
            pick 14
            call {u64_to_u128_mul}
            // _ r1 r0 l1 l0 [S*2^128] [R1*L0: u128] [R0*L1: u128]
            call {u128_safe_add}
            // _ r1 r0 l1 l0 [S*2^128] q3 q2 q1 q0

            // Q*2^64 only fits into a u160 if Q < 2^96, i.e. q3 == 0.
            pick 3
            push 0
            eq
            assert error_id {Self::OVERFLOW_4}
            // _ r1 r0 l1 l0 [S*2^128] q2 q1 q0

            // The first two zeros complete Q*2^64 = (q2, q1, q0, 0, 0); the third one
            // becomes the (zero) top limb of the 128-bit product L0*R0 below.
            push 0
            push 0
            push 0
            // _ r1 r0 l1 l0 [S*2^128] [Q*2^64] 0
            pick 14
            pick 14
            pick 14
            pick 14
            call {u64_to_u128_mul}
            // _ [S*2^128] [Q*2^64] [L0*R0 as u160]

            // Add the three terms with the u160 safe_add helper.
            call {u160_safe_add}
            call {u160_safe_add}
            // _ [product: u160]
            return
        )
    }
}
#[cfg(test)]
mod tests {
    use num::BigUint;
    use num::One;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u160::u128_to_u160;
    use crate::arithmetic::u160::u128_to_u160_shl_32;
    use crate::arithmetic::u160::u128_to_u160_shl_32_lower_limb_filled;
    use crate::test_prelude::*;

    impl SafeMul {
        // Asserts that multiplying `left` by `right` crashes the VM with one of
        // `error_ids`. The arguments are handed to the snippet as `(right, left)`,
        // the same order as `BasicSnippet::parameters`.
        fn test_assertion_failure(&self, left: [u32; 5], right: [u32; 5], error_ids: &[i128]) {
            test_assertion_failure(
                &ShadowedClosure::new(Self),
                InitVmState::with_stack(self.set_up_test_stack((right, left))),
                error_ids,
            );
        }
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeMul).test()
    }

    // Targeted inputs for each of the snippet's own overflow checks.
    #[test]
    fn overflow_unit_test() {
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(u128::MAX),
            u128_to_u160_shl_32(u128::MAX),
            &[580],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(1u128 << 64),
            u128_to_u160_shl_32(u128::MAX),
            &[581],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(u128::MAX),
            u128_to_u160_shl_32(1u128 << 64),
            &[582],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 64),
            u128_to_u160(1u128 << 96),
            &[583],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 96),
            u128_to_u160(1u128 << 64),
            &[583],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160((1u128 << 64) - 1),
            u128_to_u160(1u128 << 99),
            &[584],
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(1u128 << 99),
            u128_to_u160((1u128 << 64) - 1),
            &[584],
        );
        SafeMul.test_assertion_failure(u128_to_u160(2), u128_to_u160_shl_32(1 << 127), &[583]);
        SafeMul.test_assertion_failure(u128_to_u160_shl_32(1 << 127), u128_to_u160(2), &[583]);
    }

    // `left * right > u128::MAX`, and `left` is additionally shifted by 32 bits,
    // so the product exceeds 2^160.
    #[proptest(cases = 100)]
    fn arbitrary_overflow_crashes_vm_u128(
        #[strategy(2_u128..)] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        let left = u128_to_u160_shl_32(left);
        let right = u128_to_u160(right);
        SafeMul.test_assertion_failure(left, right, &[580, 581, 582, 583, 584, 570]);
    }

    // `right` is the smallest value with `left * right > u128::MAX`; one operand
    // is shifted by 32 bits, so the product just barely exceeds 2^160.
    #[proptest(cases = 50)]
    fn marginal_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
    ) {
        let right = u128::MAX / left + 1;
        let expected_error_codes = [580, 581, 582, 583, 584, 100, 101, 102, 103, 570];
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160(right),
            &expected_error_codes,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32(right),
            &expected_error_codes,
        );
    }

    // `left * right > u128::MAX`, combined with various operand encodings that
    // push the product beyond 2^160.
    #[proptest(cases = 50)]
    fn arbitrary_overflow_crashes_vm(
        #[strategy(2_u8..128)] _log_upper_bound: u8,
        #[strategy(2_u128..(1 << #_log_upper_bound))] left: u128,
        #[strategy(u128::MAX / #left + 1..)] right: u128,
    ) {
        let expected_error_codes = [580, 581, 582, 583, 584, 100, 101, 102, 103, 570];
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160(right),
            &expected_error_codes,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32(right),
            &expected_error_codes,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32_lower_limb_filled(left),
            u128_to_u160(right),
            &expected_error_codes,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160(left),
            u128_to_u160_shl_32_lower_limb_filled(right),
            &expected_error_codes,
        );
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32(left),
            u128_to_u160_shl_32(right),
            &expected_error_codes,
        );
        // Both operands shifted with their lowest limb filled: the product is at
        // least `left * right * 2^64 >= 2^192`.
        SafeMul.test_assertion_failure(
            u128_to_u160_shl_32_lower_limb_filled(left),
            u128_to_u160_shl_32_lower_limb_filled(right),
            &expected_error_codes,
        );
    }

    impl Closure for SafeMul {
        type Args = ([u32; 5], [u32; 5]);

        // Reference implementation: pops `left` (top) then `right`, multiplies them
        // as big integers, panics on overflow, and pushes the 5-limb product.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            let left: [u32; 5] = pop_encodable(stack);
            let left: BigUint = BigUint::new(left.to_vec());
            let right: [u32; 5] = pop_encodable(stack);
            let right: BigUint = BigUint::new(right.to_vec());
            let prod = &left * &right;
            let mut prod = prod.to_u32_digits();
            assert!(prod.len() <= 5, "Overflow: left: {left}, right: {right}.");
            prod.resize(5, 0);
            let prod: [u32; 5] = prod.try_into().unwrap();
            push_encodable(stack, &prod);
        }

        // Draws a random `lhs` and then an `rhs` such that `lhs * rhs` fits into a u160.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: [u32; 5] = rng.random();
            let lhs_as_biguint = BigUint::new(lhs.to_vec());

            // `max = floor(u160_max / lhs)`: any `rhs <= max` keeps the product within
            // 160 bits. NOTE(review): `lhs == 0` would panic here; probability 2^-160.
            let u160_max = BigUint::from_bytes_be(&[0xFF; 20]);
            let max = &u160_max / &lhs_as_biguint;

            // Big-endian mask of `bits(max)` ones, left-padded to 20 bytes, so masked
            // candidates are below 2^bits(max) and the rejection loop below only has to
            // discard values in [max, 2^bits(max)).
            let bits: u32 = max.bits().try_into().unwrap();
            let bit_mask = BigUint::from(2u32).pow(bits) - BigUint::one();
            let mut bit_mask = bit_mask.to_bytes_be();
            bit_mask.reverse();
            bit_mask.resize(20, 0);
            bit_mask.reverse();

            let mut rhs_bytes = [0u8; 20];
            let rhs = loop {
                rng.fill(&mut rhs_bytes);
                for (byte, mask) in rhs_bytes.iter_mut().zip(&bit_mask) {
                    *byte &= *mask;
                }
                let candidate = BigUint::from_bytes_be(&rhs_bytes);
                if candidate < max {
                    break candidate;
                }
            };

            // Sanity check: the generated pair must not overflow.
            {
                let prod = &lhs_as_biguint * &rhs;
                assert!(prod.to_u32_digits().len() <= 5);
            }

            let mut rhs = rhs.to_u32_digits();
            rhs.resize(5, 0);
            (lhs, rhs.try_into().unwrap())
        }

        // All non-overflowing pairs drawn from a set of boundary values.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            fn u160_checked_mul(l: [u32; 5], r: [u32; 5]) -> Option<[u32; 5]> {
                let l: BigUint = BigUint::new(l.to_vec());
                let r: BigUint = BigUint::new(r.to_vec());
                let prod = l * r;
                let mut prod = prod.to_u32_digits();
                if prod.len() > 5 {
                    None
                } else {
                    prod.resize(5, 0);
                    Some(prod.try_into().unwrap())
                }
            }

            let edge_case_points = vec![
                u128_to_u160(0),
                u128_to_u160(1),
                u128_to_u160(2),
                u128_to_u160(u8::MAX as u128),
                u128_to_u160(1 << 8),
                u128_to_u160(u16::MAX as u128),
                u128_to_u160(1 << 16),
                u128_to_u160(u32::MAX as u128),
                u128_to_u160(1 << 32),
                u128_to_u160(u64::MAX as u128),
                u128_to_u160(1 << 64),
                [u32::MAX, u32::MAX, u32::MAX, 0, 0],
                u128_to_u160(1 << 96),
                u128_to_u160(u128::MAX),
                [u32::MAX, u32::MAX, u32::MAX, u32::MAX, u32::MAX >> 1],
                [u32::MAX; 5],
            ];

            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u160_checked_mul(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Runs the benchmark harness (`ShadowedClosure::bench`) for this snippet.
    #[test]
    fn benchmark() {
        let shadowed = ShadowedClosure::new(SafeMul);
        shadowed.bench()
    }
}