use triton_vm::prelude::*;
use crate::arithmetic::u192::overflowing_add::OverflowingAdd;
use crate::prelude::*;
/// Add two `u192`s, crashing the VM if the sum does not fit in 192 bits.
///
/// Takes the arguments `l` and `r` (both `u192`) and returns their `sum` (`u192`).
/// On overflow, execution fails with error ID [`SafeAdd::OVERFLOW_ERROR_ID`].
#[derive(Clone, Debug)]
pub struct SafeAdd;

impl SafeAdd {
    /// Error ID raised by the snippet's `assert` when the addition overflows.
    pub(crate) const OVERFLOW_ERROR_ID: i128 = 600;
}
impl BasicSnippet for SafeAdd {
    /// Two `u192` operands, `l` and `r`.
    fn parameters(&self) -> Vec<(DataType, String)> {
        ["l", "r"]
            .iter()
            .map(|&name| (DataType::U192, name.to_owned()))
            .collect()
    }

    /// A single `u192` result, `sum`.
    fn return_values(&self) -> Vec<(DataType, String)> {
        let sum = (DataType::U192, String::from("sum"));
        vec![sum]
    }

    fn entrypoint(&self) -> String {
        String::from("tasmlib_arithmetic_u192_safe_add")
    }

    /// Reuses the body of [`OverflowingAdd`] and then asserts that the value it leaves
    /// on top of the stack is zero, failing with [`Self::OVERFLOW_ERROR_ID`] otherwise.
    fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
        let entrypoint = self.entrypoint();
        let overflowing_add = OverflowingAdd::addition_code();

        triton_asm! {
            {entrypoint}:
                {&overflowing_add}
                // Judging by the tests, the element on top is the overflow indicator;
                // it must be 0 for the addition to be accepted.
                push 0
                eq
                assert error_id {Self::OVERFLOW_ERROR_ID}
                return
        }
    }
}
#[cfg(test)]
mod tests {
    use num::BigUint;
    use rand::rngs::StdRng;

    use super::*;
    use crate::arithmetic::u192::U192;
    use crate::arithmetic::u192::to_u192;
    use crate::arithmetic::u192::u128_to_u192_shl64;
    use crate::test_prelude::*;

    /// Number of 32-bit limbs in a [`U192`].
    const NUM_LIMBS: usize = 6;

    /// Converts `value` into a [`U192`], or returns `None` if it needs more than 192 bits.
    fn biguint_to_u192(value: BigUint) -> Option<U192> {
        let mut limbs = value.to_u32_digits();
        if limbs.len() > NUM_LIMBS {
            return None;
        }
        // `to_u32_digits` only yields as many limbs as needed; pad up to six.
        limbs.resize(NUM_LIMBS, 0);
        limbs.try_into().ok()
    }

    /// Adds two [`U192`]s, returning `None` if the sum does not fit in 192 bits.
    fn u192_checked_add(l: U192, r: U192) -> Option<U192> {
        biguint_to_u192(BigUint::new(l.to_vec()) + BigUint::new(r.to_vec()))
    }

    #[test]
    fn rust_shadow() {
        ShadowedClosure::new(SafeAdd).test()
    }

    /// The snippet must fail with [`SafeAdd::OVERFLOW_ERROR_ID`] for hand-picked operand
    /// pairs whose sum exceeds 192 bits, including pairs whose `u128` parts sum to
    /// exactly `2^128`.
    #[test]
    fn overflow_test() {
        for (left, right) in [
            (1 << 127, 1 << 127),
            (u128::MAX, u128::MAX),
            (u128::MAX, 1),
            (u128::MAX, 1 << 31),
            (u128::MAX, 1 << 32),
            (u128::MAX, 1 << 33),
            (u128::MAX, 1 << 63),
            (u128::MAX, 1 << 64),
            (u128::MAX, 1 << 65),
            (u128::MAX, 1 << 95),
            (u128::MAX, 1 << 96),
            (u128::MAX, 1 << 97),
            (u128::MAX - 1, 2),
        ]
        .into_iter()
        .flat_map(|(left, right)| [(left, right), (right, left)])
        {
            let left = to_u192(left, u64::MAX);
            let right = u128_to_u192_shl64(right);
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack((left, right))),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }

        for i in 0..128 {
            let left = 1 << i;
            let right = u128::MAX - left + 1;
            assert_eq!(
                (0, true),
                left.overflowing_add(right),
                "i = {i}. a = {left}, b = {right}"
            );
            let left = to_u192(left, u64::MAX);
            let right = u128_to_u192_shl64(right);
            test_assertion_failure(
                &ShadowedClosure::new(SafeAdd),
                InitVmState::with_stack(SafeAdd.set_up_test_stack((left, right))),
                &[SafeAdd::OVERFLOW_ERROR_ID],
            );
        }
    }

    impl Closure for SafeAdd {
        type Args = <OverflowingAdd as Closure>::Args;

        /// Pops two `U192`s, pushes their sum, and panics with "Overflow" if the sum
        /// does not fit in 192 bits.
        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) {
            // Addition commutes, so the pop order does not affect the result.
            let left: U192 = pop_encodable(stack);
            let right: U192 = pop_encodable(stack);
            let sum = u192_checked_add(left, right).expect("Overflow");
            push_encodable(stack, &sum);
        }

        /// Samples a uniformly random `lhs` and then an `rhs` in `0..=u192::MAX - lhs`,
        /// so that the pair never overflows.
        fn pseudorandom_args(&self, seed: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            let mut rng = StdRng::from_seed(seed);
            let lhs: U192 = rng.random();

            // Largest `rhs` (inclusive) that can be added to `lhs` without overflow.
            let u192_max = BigUint::from_bytes_be(&[0xFF; 24]);
            let rhs_bound = &u192_max - BigUint::new(lhs.to_vec());

            // Rejection sampling restricted to the bound's bit length: every attempt is
            // accepted with probability >= 1/2, and a bound of 0 (lhs == u192::MAX) yields
            // candidate 0 immediately because the shift discards all 192 bits.
            let num_bits = rhs_bound.bits();
            let mut rhs_bytes = [0u8; 24];
            let rhs = loop {
                rng.fill(&mut rhs_bytes);
                let candidate = BigUint::from_bytes_be(&rhs_bytes) >> (192 - num_bits);
                if candidate <= rhs_bound {
                    break candidate;
                }
            };

            let rhs = biguint_to_u192(rhs).expect("rhs <= u192::MAX - lhs fits in 192 bits");
            (lhs, rhs)
        }

        /// All pairs of [`OverflowingAdd::edge_case_points`] whose sum does not overflow.
        fn corner_case_args(&self) -> Vec<Self::Args> {
            let edge_case_points = OverflowingAdd::edge_case_points();
            edge_case_points
                .iter()
                .cartesian_product(&edge_case_points)
                .filter(|&(&l, &r)| u192_checked_add(l, r).is_some())
                .map(|(&l, &r)| (l, r))
                .collect()
        }
    }
}
#[cfg(test)]
mod benches {
    use super::*;
    use crate::test_prelude::*;

    /// Benchmarks [`SafeAdd`] via its shadowed closure.
    #[test]
    fn benchmark() {
        let shadowed = ShadowedClosure::new(SafeAdd);
        shadowed.bench();
    }
}