//! `fluentbase_runtime/syscall_handler/tower/tower_fp1_add_sub_mul.rs`

1use crate::{syscall_handler::syscall_process_exit_code, RuntimeContext};
2use fluentbase_types::{ExitCode, BLS12381_FP_SIZE, BN254_FP_SIZE};
3use num::BigUint;
4use rwasm::{StoreTr, TrapCode, Value};
5use sp1_curves::weierstrass::{bls12_381::Bls12381BaseField, bn254::Bn254BaseField, FpOpField};
6
7pub fn syscall_tower_fp1_bn254_add_handler(
8    ctx: &mut impl StoreTr<RuntimeContext>,
9    params: &[Value],
10    _result: &mut [Value],
11) -> Result<(), TrapCode> {
12    syscall_tower_fp1_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(
13        ctx, params, _result,
14    )
15}
16pub fn syscall_tower_fp1_bn254_sub_handler(
17    ctx: &mut impl StoreTr<RuntimeContext>,
18    params: &[Value],
19    _result: &mut [Value],
20) -> Result<(), TrapCode> {
21    syscall_tower_fp1_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(
22        ctx, params, _result,
23    )
24}
25pub fn syscall_tower_fp1_bn254_mul_handler(
26    ctx: &mut impl StoreTr<RuntimeContext>,
27    params: &[Value],
28    _result: &mut [Value],
29) -> Result<(), TrapCode> {
30    syscall_tower_fp1_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(
31        ctx, params, _result,
32    )
33}
34pub fn syscall_tower_fp1_bls12381_add_handler(
35    ctx: &mut impl StoreTr<RuntimeContext>,
36    params: &[Value],
37    _result: &mut [Value],
38) -> Result<(), TrapCode> {
39    syscall_tower_fp1_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(
40        ctx, params, _result,
41    )
42}
43pub fn syscall_tower_fp1_bls12381_sub_handler(
44    ctx: &mut impl StoreTr<RuntimeContext>,
45    params: &[Value],
46    _result: &mut [Value],
47) -> Result<(), TrapCode> {
48    syscall_tower_fp1_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(
49        ctx, params, _result,
50    )
51}
52pub fn syscall_tower_fp1_bls12381_mul_handler(
53    ctx: &mut impl StoreTr<RuntimeContext>,
54    params: &[Value],
55    _result: &mut [Value],
56) -> Result<(), TrapCode> {
57    syscall_tower_fp1_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(
58        ctx, params, _result,
59    )
60}
61
// Operation selectors, passed as the `FIELD_OP` const generic parameter of the shared
// handler and impl functions in this file.

/// Selects modular addition.
const FP_FIELD_ADD: u32 = 1;
/// Selects modular subtraction.
const FP_FIELD_SUB: u32 = 2;
/// Selects modular multiplication.
const FP_FIELD_MUL: u32 = 3;

66pub(crate) fn syscall_tower_fp1_add_sub_mul_handler<
67    const NUM_BYTES: usize,
68    P: FpOpField,
69    const FIELD_OP: u32,
70>(
71    ctx: &mut impl StoreTr<RuntimeContext>,
72    params: &[Value],
73    _result: &mut [Value],
74) -> Result<(), TrapCode> {
75    let (x_ptr, y_ptr) = (
76        params[0].i32().unwrap() as u32,
77        params[1].i32().unwrap() as u32,
78    );
79    let mut x = [0u8; NUM_BYTES];
80    ctx.memory_read(x_ptr as usize, &mut x)?;
81    let mut y = [0u8; NUM_BYTES];
82    ctx.memory_read(y_ptr as usize, &mut y)?;
83
84    let result = syscall_tower_fp1_add_sub_mul_impl::<NUM_BYTES, P, FIELD_OP>(
85        x,
86        y,
87    )
88    .map_err(|exit_code| syscall_process_exit_code(ctx, exit_code))?;
89
90    ctx.memory_write(x_ptr as usize, &result)?;
91    Ok(())
92}
93
94pub fn syscall_tower_fp1_bn254_add_impl(
95    x: [u8; BN254_FP_SIZE],
96    y: [u8; BN254_FP_SIZE],
97) -> Result<[u8; BN254_FP_SIZE], ExitCode> {
98    syscall_tower_fp1_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(x, y)
99}
100pub fn syscall_tower_fp1_bn254_sub_impl(
101    x: [u8; BN254_FP_SIZE],
102    y: [u8; BN254_FP_SIZE],
103) -> Result<[u8; BN254_FP_SIZE], ExitCode> {
104    syscall_tower_fp1_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(x, y)
105}
106pub fn syscall_tower_fp1_bn254_mul_impl(
107    x: [u8; BN254_FP_SIZE],
108    y: [u8; BN254_FP_SIZE],
109) -> Result<[u8; BN254_FP_SIZE], ExitCode> {
110    syscall_tower_fp1_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(x, y)
111}
112pub fn syscall_tower_fp1_bls12381_add_impl(
113    x: [u8; BLS12381_FP_SIZE],
114    y: [u8; BLS12381_FP_SIZE],
115) -> Result<[u8; BLS12381_FP_SIZE], ExitCode> {
116    syscall_tower_fp1_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(x, y)
117}
118pub fn syscall_tower_fp1_bls12381_sub_impl(
119    x: [u8; BLS12381_FP_SIZE],
120    y: [u8; BLS12381_FP_SIZE],
121) -> Result<[u8; BLS12381_FP_SIZE], ExitCode> {
122    syscall_tower_fp1_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(x, y)
123}
124pub fn syscall_tower_fp1_bls12381_mul_impl(
125    x: [u8; BLS12381_FP_SIZE],
126    y: [u8; BLS12381_FP_SIZE],
127) -> Result<[u8; BLS12381_FP_SIZE], ExitCode> {
128    syscall_tower_fp1_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(x, y)
129}
130
131pub(crate) fn syscall_tower_fp1_add_sub_mul_impl<
132    const NUM_BYTES: usize,
133    P: FpOpField,
134    const FIELD_OP: u32,
135>(
136    x: [u8; NUM_BYTES],
137    y: [u8; NUM_BYTES],
138) -> Result<[u8; NUM_BYTES], ExitCode> {
139    let modulus = &BigUint::from_bytes_le(P::MODULUS);
140    let a = BigUint::from_bytes_le(&x);
141    let b = BigUint::from_bytes_le(&y);
142    let result = match FIELD_OP {
143        FP_FIELD_ADD => (a + b) % modulus,
144        FP_FIELD_SUB => {
145            if &a + modulus < b {
146                return Err(ExitCode::MalformedBuiltinParams);
147            }
148            ((a + modulus) - b) % modulus
149        }
150        FP_FIELD_MUL => (a * b) % modulus,
151        _ => unreachable!(),
152    };
153    let mut result = result.to_bytes_le();
154    result.resize(NUM_BYTES, 0);
155    let result: [u8; NUM_BYTES] = result.try_into().expect("length checked");
156    Ok(result)
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::str::FromStr;

    /// BN254 base-field modulus, in decimal.
    const BN254_MODULUS: &str =
        "21888242871839275222246405745257275088696311157297823662689037894645226208583";
    /// BLS12-381 base-field modulus, in decimal.
    const BLS12381_MODULUS: &str = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787";

    /// Signature shared by the `syscall_tower_fp1_*_impl` functions of one field width.
    type FieldOp<const NUM_BYTES: usize> =
        fn([u8; NUM_BYTES], [u8; NUM_BYTES]) -> Result<[u8; NUM_BYTES], ExitCode>;

    /// Draws `NUM_BYTES` random bytes, reads them as a little-endian integer and reduces the
    /// value modulo `modulus`.
    fn random_bigint<const NUM_BYTES: usize>(modulus: &BigUint) -> BigUint {
        let mut rng = rand::rng();
        let mut arr = [0u8; NUM_BYTES];
        for item in arr.iter_mut() {
            *item = rng.random();
        }
        BigUint::from_bytes_le(&arr) % modulus
    }

    /// Encodes `a` as `NUM_BYTES` zero-padded little-endian bytes, panicking if it does not fit.
    fn big_uint_into_bytes<const NUM_BYTES: usize>(a: &BigUint) -> [u8; NUM_BYTES] {
        let mut res = a.to_bytes_le();
        assert!(
            res.len() <= NUM_BYTES,
            "{} does not fit into {} bytes",
            a,
            NUM_BYTES
        );
        res.resize(NUM_BYTES, 0);
        res.try_into().expect("length checked")
    }

    /// Runs `op` on `BigUint` operands and asserts that the output is reduced (`< modulus`).
    fn apply<const NUM_BYTES: usize>(
        op: FieldOp<NUM_BYTES>,
        modulus: &BigUint,
        a: &BigUint,
        b: &BigUint,
    ) -> BigUint {
        let bytes = op(big_uint_into_bytes(a), big_uint_into_bytes(b)).unwrap();
        let out = BigUint::from_bytes_le(&bytes);
        assert!(&out < modulus, "result {} is not reduced modulo {}", out, modulus);
        out
    }

    /// Checks add/sub/mul of one field against a `BigUint` reference model.
    ///
    /// Results are compared as returned, not reduced again first, so an implementation that
    /// returned unreduced values would fail here.
    fn check_field<const NUM_BYTES: usize>(
        modulus: &BigUint,
        add: FieldOp<NUM_BYTES>,
        sub: FieldOp<NUM_BYTES>,
        mul: FieldOp<NUM_BYTES>,
    ) {
        let zero = BigUint::ZERO;
        let one = BigUint::from(1u32);

        for _ in 0..10 {
            let a = random_bigint::<NUM_BYTES>(modulus);
            let b = random_bigint::<NUM_BYTES>(modulus);

            // Addition, including the additive identity.
            assert_eq!(apply(add, modulus, &a, &b), (&a + &b) % modulus);
            assert_eq!(apply(add, modulus, &a, &zero), a);

            // Subtraction: the reference wraps around by adding `modulus` when `a < b`.
            let expected_sub = if a < b { (&a + modulus) - &b } else { &a - &b };
            assert_eq!(apply(sub, modulus, &a, &b), expected_sub);
            assert_eq!(apply(sub, modulus, &a, &zero), a);

            // Multiplication, including the multiplicative identity and zero.
            assert_eq!(apply(mul, modulus, &a, &b), (&a * &b) % modulus);
            assert_eq!(apply(mul, modulus, &a, &one), a);
            assert_eq!(apply(mul, modulus, &a, &zero), zero);
        }

        // Wrap-around at the top of the field: with `max = modulus - 1`, which is -1 mod p,
        // `max + 1 = 0`, `0 - 1 = max` and `max * max = 1`.
        let max = modulus - &one;
        assert_eq!(apply(add, modulus, &max, &one), zero);
        assert_eq!(apply(sub, modulus, &zero, &one), max);
        assert_eq!(apply(mul, modulus, &max, &max), one);

        // Unreduced operands: the all-0xff pattern exceeds the modulus for both fields tested
        // here, yet every result must still come back reduced.
        let wide = BigUint::from_bytes_le(&[0xffu8; NUM_BYTES]);
        assert!(&wide > modulus);
        assert_eq!(apply(add, modulus, &wide, &zero), &wide % modulus);
        assert_eq!(apply(sub, modulus, &wide, &zero), &wide % modulus);
        assert_eq!(apply(mul, modulus, &wide, &one), &wide % modulus);
    }

    /// Adapted from: sp1/crates/test-artifacts/programs/bn254-fp/src/main.rs
    #[test]
    fn test_bn254_fp1() {
        let modulus = BigUint::from_str(BN254_MODULUS).unwrap();
        check_field::<BN254_FP_SIZE>(
            &modulus,
            syscall_tower_fp1_bn254_add_impl,
            syscall_tower_fp1_bn254_sub_impl,
            syscall_tower_fp1_bn254_mul_impl,
        );
    }

    /// Adapted from: sp1/crates/test-artifacts/programs/bls12381-fp/src/main.rs
    #[test]
    fn test_bls12381_fp1() {
        let modulus = BigUint::from_str(BLS12381_MODULUS).unwrap();
        check_field::<BLS12381_FP_SIZE>(
            &modulus,
            syscall_tower_fp1_bls12381_add_impl,
            syscall_tower_fp1_bls12381_sub_impl,
            syscall_tower_fp1_bls12381_mul_impl,
        );
    }

    /// Subtraction must reject `y > x + modulus` with an error instead of underflowing.
    #[test]
    fn test_bls12381_overflow_cant_happen() {
        let m = BigUint::from_str(BLS12381_MODULUS).unwrap();
        let a = BigUint::from_str("3001807166416250545063342369301928117417662114954255913999043602093023737868128398332015721846761748028420704419841").unwrap();
        let b = BigUint::from_str("12007228665665002180253369477207712469670648459817023655996174408372094951472513593328062887387046992113682817679361").unwrap();
        // `b` exceeds `a + m`, the largest value the implementation can subtract from `a`.
        assert!(a < m && b > &a + &m);

        let result =
            syscall_tower_fp1_bls12381_sub_impl(big_uint_into_bytes(&a), big_uint_into_bytes(&b));
        assert!(
            matches!(result, Err(ExitCode::MalformedBuiltinParams)),
            "expected Err(MalformedBuiltinParams), got {:?}",
            result
        );
    }
}