Skip to main content

fluentbase_runtime/syscall_handler/uint256/
uint256_mul_mod.rs

1use crate::RuntimeContext;
2use num::{BigUint, One, Zero};
3use rwasm::{StoreTr, TrapCode, Value};
4
5pub fn syscall_uint256_mul_mod_handler(
6    caller: &mut impl StoreTr<RuntimeContext>,
7    params: &[Value],
8    _result: &mut [Value],
9) -> Result<(), TrapCode> {
10    let (x_ptr, y_ptr, m_ptr) = (
11        params[0].i32().unwrap() as usize,
12        params[1].i32().unwrap() as usize,
13        params[2].i32().unwrap() as usize,
14    );
15
16    let mut x = [0u8; 32];
17    caller.memory_read(x_ptr, &mut x)?;
18    let mut y = [0u8; 32];
19    caller.memory_read(y_ptr, &mut y)?;
20    let mut m = [0u8; 32];
21    caller.memory_read(m_ptr, &mut m)?;
22
23    let result_vec = syscall_uint256_mul_mod_impl(&x, &y, &m);
24    caller.memory_write(x_ptr, &result_vec)
25}
26
27pub fn syscall_uint256_mul_mod_impl(x: &[u8; 32], y: &[u8; 32], m: &[u8; 32]) -> [u8; 32] {
28    // Get the BigUint values for x, y, and the modulus.
29    let uint256_x = BigUint::from_bytes_le(x);
30    let uint256_y = BigUint::from_bytes_le(y);
31    let uint256_m = BigUint::from_bytes_le(m);
32
33    // Perform the multiplication and take the result modulo the modulus.
34    let result: BigUint = if uint256_m.is_zero() {
35        let modulus = BigUint::one() << 256;
36        (uint256_x * uint256_y) % modulus
37    } else {
38        (uint256_x * uint256_y) % uint256_m
39    };
40
41    let mut result_bytes = result.to_bytes_le();
42    result_bytes.resize(32, 0u8); // Pad the result to 32 bytes.
43    let mut result = [0u8; 32];
44    result.copy_from_slice(&result_bytes);
45    result
46}
47
/// Randomized and edge-case checks for `syscall_uint256_mul_mod_impl`, taken from:
/// sp1/crates/test-artifacts/programs/uint256-mul/src/main.rs
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;

    /// Encodes `x` as little-endian bytes, zero-padded (or cut) to exactly 32 bytes.
    fn biguint_to_bytes_le(x: BigUint) -> [u8; 32] {
        let mut padded = [0u8; 32];
        for (dst, src) in padded.iter_mut().zip(x.to_bytes_le()) {
            *dst = src;
        }
        padded
    }

    #[test]
    fn test_u256_mul_mod() {
        let mut rng = rand::rng();

        // Random operands against a random modulus.
        for _ in 0..50 {
            let raw_x: [u8; 32] = rng.random();
            let raw_y: [u8; 32] = rng.random();
            let modulus: [u8; 32] = rng.random();

            let modulus_big = BigUint::from_bytes_le(&modulus);
            let x_big = BigUint::from_bytes_le(&raw_x);
            let y_big = BigUint::from_bytes_le(&raw_y);

            // The syscall is fed operands already reduced modulo the modulus.
            let x = biguint_to_bytes_le(&x_big % &modulus_big);
            let y = biguint_to_bytes_le(&y_big % &modulus_big);

            let result_syscall =
                BigUint::from_bytes_le(&syscall_uint256_mul_mod_impl(&x, &y, &modulus));
            let result = (x_big * y_big) % modulus_big;

            assert_eq!(result, result_syscall);
        }

        // An all-zero modulus must behave like a modulus of 2^256.
        let modulus = [0u8; 32];
        let modulus_big: BigUint = BigUint::one() << 256;
        for _ in 0..50 {
            let raw_x: [u8; 32] = rng.random();
            let raw_y: [u8; 32] = rng.random();

            let x_big = BigUint::from_bytes_le(&raw_x);
            let y_big = BigUint::from_bytes_le(&raw_y);
            let x = biguint_to_bytes_le(&x_big % &modulus_big);
            let y = biguint_to_bytes_le(&y_big % &modulus_big);

            let result_syscall =
                BigUint::from_bytes_le(&syscall_uint256_mul_mod_impl(&x, &y, &modulus));
            let result = (x_big * y_big) % &modulus_big;

            assert_eq!(result, result_syscall, "x: {:?}, y: {:?}", x, y);
        }

        // Hard-coded edge cases: a random `x` multiplied by 1 and by 0 under the
        // zero modulus (i.e. modulo 2^256).
        let x: [u8; 32] = rng.random();
        let modulus = [0u8; 32];

        // Little-endian encoding of 1: only the least significant byte is set.
        let mut one = [0u8; 32];
        one[0] = 1;
        let result_one = syscall_uint256_mul_mod_impl(&x, &one, &modulus);
        assert_eq!(
            result_one, x,
            "Multiplying by 1 should yield the same number."
        );

        let zero = [0u8; 32];
        let result_zero = syscall_uint256_mul_mod_impl(&x, &zero, &modulus);
        assert_eq!(result_zero, zero, "Multiplying by 0 should yield 0.");
    }
}