Skip to main content

fluentbase_runtime/syscall_handler/uint256/
uint256_mul_x2048.rs

1use crate::RuntimeContext;
2use num::{BigUint, Integer, One};
3use rwasm::{StoreTr, TrapCode, Value};
4
/// Byte width of a 256-bit operand (the `a` input and the `hi` output half).
const U256_NUM_BYTES: usize = 256 / 8;
/// Byte width of a 2048-bit operand (the `b` input and the `lo` output half).
const U2048_NUM_BYTES: usize = 2048 / 8;
7
8pub fn syscall_uint256_x2048_mul_handler(
9    ctx: &mut impl StoreTr<RuntimeContext>,
10    params: &[Value],
11    _result: &mut [Value],
12) -> Result<(), TrapCode> {
13    let (a_ptr, b_ptr, lo_ptr, hi_ptr) = (
14        params[0].i32().unwrap() as usize,
15        params[1].i32().unwrap() as usize,
16        params[2].i32().unwrap() as usize,
17        params[3].i32().unwrap() as usize,
18    );
19
20    let mut a = [0u8; U256_NUM_BYTES];
21    ctx.memory_read(a_ptr, &mut a)?;
22    let mut b = [0u8; U2048_NUM_BYTES];
23    ctx.memory_read(b_ptr, &mut b)?;
24
25    let (lo_bytes, hi_bytes) = syscall_uint256_x2048_mul_impl(&a, &b);
26
27    ctx.memory_write(lo_ptr, &lo_bytes)?;
28    ctx.memory_write(hi_ptr, &hi_bytes)?;
29    Ok(())
30}
31
32pub fn syscall_uint256_x2048_mul_impl(
33    a: &[u8; U256_NUM_BYTES],
34    b: &[u8; U2048_NUM_BYTES],
35) -> ([u8; U2048_NUM_BYTES], [u8; U256_NUM_BYTES]) {
36    let uint256_a = BigUint::from_bytes_le(a);
37    let uint2048_b = BigUint::from_bytes_le(b);
38    let result = uint256_a * uint2048_b;
39    let two_to_2048 = BigUint::one() << 2048;
40    let (hi, lo) = result.div_rem(&two_to_2048);
41    let mut lo_bytes = lo.to_bytes_le();
42    lo_bytes.resize(U2048_NUM_BYTES, 0u8);
43    let lo_res: [u8; U2048_NUM_BYTES] = lo_bytes.try_into().unwrap();
44    let mut hi_bytes = hi.to_bytes_le();
45    hi_bytes.resize(U256_NUM_BYTES, 0u8);
46    let hi_res: [u8; U256_NUM_BYTES] = hi_bytes.try_into().unwrap();
47    (lo_res, hi_res)
48}
49
50/// These tests are taken from: sp1/crates/test-artifacts/programs/u256x2048-mul/src/main.rs
51#[cfg(test)]
52mod tests {
53    use super::*;
54    use rand::Rng;
55
56    fn u256_to_bytes_le(x: &BigUint) -> [u8; 32] {
57        let mut bytes = x.to_bytes_le();
58        bytes.resize(32, 0);
59        bytes.try_into().unwrap()
60    }
61
62    fn u2048_to_bytes_le(x: &BigUint) -> [u8; 256] {
63        let mut bytes = x.to_bytes_le();
64        bytes.resize(256, 0);
65        bytes.try_into().unwrap()
66    }
67
68    #[test]
69    fn test_uin256_x2048_mul_sp1() {
70        let mut a_max: [u8; 32] = [0xff; 32];
71        let mut b_max: [u8; 256] = [0xff; 256];
72
73        let a_max_big = BigUint::from_bytes_le(&a_max);
74        a_max = u256_to_bytes_le(&a_max_big);
75        let b_max_big = BigUint::from_bytes_le(&b_max);
76        b_max = u2048_to_bytes_le(&b_max_big);
77
78        let (lo_max_bytes, hi_max_bytes) = syscall_uint256_x2048_mul_impl(&a_max, &b_max);
79
80        let lo_max_big = BigUint::from_bytes_le(&lo_max_bytes);
81        let hi_max_big = BigUint::from_bytes_le(&hi_max_bytes);
82
83        let result_max_syscall = (hi_max_big << 2048) + lo_max_big;
84        let result_max = a_max_big * b_max_big;
85        assert_eq!(result_max, result_max_syscall);
86
87        // Test 10 random pairs of a and b.
88        let mut rng = rand::rng();
89        for _ in 0..10 {
90            let a: [u8; 32] = rng.random();
91            let mut b = [0u8; 256];
92            rng.fill(&mut b);
93
94            let a_big = BigUint::from_bytes_le(&a);
95            let b_big = BigUint::from_bytes_le(&b);
96
97            let a = u256_to_bytes_le(&a_big);
98            let b = u2048_to_bytes_le(&b_big);
99
100            let (lo_bytes, hi_bytes) = syscall_uint256_x2048_mul_impl(&a, &b);
101
102            let lo_big = BigUint::from_bytes_le(&lo_bytes);
103            let hi_big = BigUint::from_bytes_le(&hi_bytes);
104
105            let result_syscall = (hi_big << 2048) + lo_big;
106            let result = a_big * b_big;
107            assert_eq!(result, result_syscall);
108        }
109    }
110}