Skip to main content

fluentbase_runtime/syscall_handler/tower/
tower_fp2_add_sub_mul.rs

1use crate::{syscall_handler::syscall_process_exit_code, RuntimeContext};
2use fluentbase_types::{ExitCode, BLS12381_FP_SIZE, BN254_FP_SIZE};
3use num::BigUint;
4use rwasm::{StoreTr, TrapCode, Value};
5use sp1_curves::weierstrass::{bls12_381::Bls12381BaseField, bn254::Bn254BaseField, FpOpField};
6
7pub fn syscall_tower_fp2_bn254_add_handler(
8    ctx: &mut impl StoreTr<RuntimeContext>,
9    params: &[Value],
10    _result: &mut [Value],
11) -> Result<(), TrapCode> {
12    syscall_tower_fp2_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(
13        ctx, params, _result,
14    )
15}
16pub fn syscall_tower_fp2_bn254_sub_handler(
17    ctx: &mut impl StoreTr<RuntimeContext>,
18    params: &[Value],
19    _result: &mut [Value],
20) -> Result<(), TrapCode> {
21    syscall_tower_fp2_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(
22        ctx, params, _result,
23    )
24}
25pub fn syscall_tower_fp2_bn254_mul_handler(
26    ctx: &mut impl StoreTr<RuntimeContext>,
27    params: &[Value],
28    _result: &mut [Value],
29) -> Result<(), TrapCode> {
30    syscall_tower_fp2_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(
31        ctx, params, _result,
32    )
33}
34pub fn syscall_tower_fp2_bls12381_add_handler(
35    ctx: &mut impl StoreTr<RuntimeContext>,
36    params: &[Value],
37    _result: &mut [Value],
38) -> Result<(), TrapCode> {
39    syscall_tower_fp2_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(
40        ctx, params, _result,
41    )
42}
43pub fn syscall_tower_fp2_bls12381_sub_handler(
44    ctx: &mut impl StoreTr<RuntimeContext>,
45    params: &[Value],
46    _result: &mut [Value],
47) -> Result<(), TrapCode> {
48    syscall_tower_fp2_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(
49        ctx, params, _result,
50    )
51}
52pub fn syscall_tower_fp2_bls12381_mul_handler(
53    ctx: &mut impl StoreTr<RuntimeContext>,
54    params: &[Value],
55    _result: &mut [Value],
56) -> Result<(), TrapCode> {
57    syscall_tower_fp2_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(
58        ctx, params, _result,
59    )
60}
61
/// `FIELD_OP` selector for the shared Fp2 handler/impl: coefficient-wise `a + b`.
const FP_FIELD_ADD: u32 = 1;
/// `FIELD_OP` selector for the shared Fp2 handler/impl: coefficient-wise `a - b`.
const FP_FIELD_SUB: u32 = 2;
/// `FIELD_OP` selector for the shared Fp2 handler/impl: Fp2 product `a * b`.
const FP_FIELD_MUL: u32 = 3;

66pub(crate) fn syscall_tower_fp2_add_sub_mul_handler<
67    const NUM_BYTES: usize,
68    P: FpOpField,
69    const FIELD_OP: u32,
70>(
71    ctx: &mut impl StoreTr<RuntimeContext>,
72    params: &[Value],
73    _result: &mut [Value],
74) -> Result<(), TrapCode> {
75    let (x_ptr, y_ptr) = (
76        params[0].i32().unwrap() as u32,
77        params[1].i32().unwrap() as u32,
78    );
79    let mut ac0 = [0u8; NUM_BYTES];
80    let mut ac1 = [0u8; NUM_BYTES];
81    ctx.memory_read(x_ptr as usize, &mut ac0)?;
82    ctx.memory_read(x_ptr as usize + NUM_BYTES, &mut ac1)?;
83    let mut bc0 = [0u8; NUM_BYTES];
84    let mut bc1 = [0u8; NUM_BYTES];
85    ctx.memory_read(y_ptr as usize, &mut bc0)?;
86    ctx.memory_read(y_ptr as usize + NUM_BYTES, &mut bc1)?;
87
88    let (res0, res1) = syscall_tower_fp2_add_sub_mul_impl::<NUM_BYTES, P, FIELD_OP>(
89        ac0,
90        ac1,
91        bc0,
92        bc1,
93    )
94    .map_err(|exit_code| syscall_process_exit_code(ctx, exit_code))?;
95
96    ctx.memory_write(x_ptr as usize, &res0)?;
97    ctx.memory_write(x_ptr as usize + NUM_BYTES, &res1)?;
98    Ok(())
99}
100
101pub fn syscall_tower_fp2_bn254_add_impl(
102    ac0: [u8; BN254_FP_SIZE],
103    ac1: [u8; BN254_FP_SIZE],
104    bc0: [u8; BN254_FP_SIZE],
105    bc1: [u8; BN254_FP_SIZE],
106) -> Result<([u8; BN254_FP_SIZE], [u8; BN254_FP_SIZE]), ExitCode> {
107    syscall_tower_fp2_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(
108        ac0, ac1, bc0, bc1,
109    )
110}
111pub fn syscall_tower_fp2_bn254_sub_impl(
112    ac0: [u8; BN254_FP_SIZE],
113    ac1: [u8; BN254_FP_SIZE],
114    bc0: [u8; BN254_FP_SIZE],
115    bc1: [u8; BN254_FP_SIZE],
116) -> Result<([u8; BN254_FP_SIZE], [u8; BN254_FP_SIZE]), ExitCode> {
117    syscall_tower_fp2_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(
118        ac0, ac1, bc0, bc1,
119    )
120}
121pub fn syscall_tower_fp2_bn254_mul_impl(
122    ac0: [u8; BN254_FP_SIZE],
123    ac1: [u8; BN254_FP_SIZE],
124    bc0: [u8; BN254_FP_SIZE],
125    bc1: [u8; BN254_FP_SIZE],
126) -> Result<([u8; BN254_FP_SIZE], [u8; BN254_FP_SIZE]), ExitCode> {
127    syscall_tower_fp2_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(
128        ac0, ac1, bc0, bc1,
129    )
130}
131pub fn syscall_tower_fp2_bls12381_add_impl(
132    ac0: [u8; BLS12381_FP_SIZE],
133    ac1: [u8; BLS12381_FP_SIZE],
134    bc0: [u8; BLS12381_FP_SIZE],
135    bc1: [u8; BLS12381_FP_SIZE],
136) -> Result<([u8; BLS12381_FP_SIZE], [u8; BLS12381_FP_SIZE]), ExitCode> {
137    syscall_tower_fp2_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(
138        ac0, ac1, bc0, bc1,
139    )
140}
141pub fn syscall_tower_fp2_bls12381_sub_impl(
142    ac0: [u8; BLS12381_FP_SIZE],
143    ac1: [u8; BLS12381_FP_SIZE],
144    bc0: [u8; BLS12381_FP_SIZE],
145    bc1: [u8; BLS12381_FP_SIZE],
146) -> Result<([u8; BLS12381_FP_SIZE], [u8; BLS12381_FP_SIZE]), ExitCode> {
147    syscall_tower_fp2_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(
148        ac0, ac1, bc0, bc1,
149    )
150}
151pub fn syscall_tower_fp2_bls12381_mul_impl(
152    ac0: [u8; BLS12381_FP_SIZE],
153    ac1: [u8; BLS12381_FP_SIZE],
154    bc0: [u8; BLS12381_FP_SIZE],
155    bc1: [u8; BLS12381_FP_SIZE],
156) -> Result<([u8; BLS12381_FP_SIZE], [u8; BLS12381_FP_SIZE]), ExitCode> {
157    syscall_tower_fp2_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(
158        ac0, ac1, bc0, bc1,
159    )
160}
161
162pub(crate) fn syscall_tower_fp2_add_sub_mul_impl<
163    const NUM_BYTES: usize,
164    P: FpOpField,
165    const FIELD_OP: u32,
166>(
167    ac0: [u8; NUM_BYTES],
168    ac1: [u8; NUM_BYTES],
169    bc0: [u8; NUM_BYTES],
170    bc1: [u8; NUM_BYTES],
171) -> Result<([u8; NUM_BYTES], [u8; NUM_BYTES]), ExitCode> {
172    let ac0 = &BigUint::from_bytes_le(&ac0);
173    let ac1 = &BigUint::from_bytes_le(&ac1);
174    let bc0 = &BigUint::from_bytes_le(&bc0);
175    let bc1 = &BigUint::from_bytes_le(&bc1);
176    let modulus = &BigUint::from_bytes_le(P::MODULUS);
177    let (c0, c1) = match FIELD_OP {
178        FP_FIELD_ADD => ((ac0 + bc0) % modulus, (ac1 + bc1) % modulus),
179        FP_FIELD_SUB => {
180            if ac0 + modulus < *bc0 || ac1 + modulus < *bc1 {
181                return Err(ExitCode::MalformedBuiltinParams);
182            }
183            (
184                (ac0 + modulus - bc0) % modulus,
185                (ac1 + modulus - bc1) % modulus,
186            )
187        }
188        FP_FIELD_MUL => {
189            let c0 = match (ac0 * bc0) % modulus < (ac1 * bc1) % modulus {
190                true => ((modulus + (ac0 * bc0) % modulus) - (ac1 * bc1) % modulus) % modulus,
191                false => ((ac0 * bc0) % modulus - (ac1 * bc1) % modulus) % modulus,
192            };
193            let c1 = ((ac0 * bc1) % modulus + (ac1 * bc0) % modulus) % modulus;
194            (c0, c1)
195        }
196        _ => unreachable!(),
197    };
198    let mut res0 = c0.to_bytes_le();
199    res0.resize(NUM_BYTES, 0);
200    let mut res1 = c1.to_bytes_le();
201    res1.resize(NUM_BYTES, 0);
202    let result: ([u8; NUM_BYTES], [u8; NUM_BYTES]) = (
203        res0.try_into().expect("length checked"),
204        res1.try_into().expect("length checked"),
205    );
206    Ok(result)
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::str::FromStr;

    /// BN254 base-field modulus, in decimal.
    const BN254_MODULUS: &str =
        "21888242871839275222246405745257275088696311157297823662689037894645226208583";

    /// BLS12-381 base-field modulus, in decimal.
    const BLS12381_MODULUS: &str =
        "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787";

    /// Signature shared by the public `syscall_tower_fp2_*_impl` entry points.
    type Fp2Op<const NUM_BYTES: usize> = fn(
        [u8; NUM_BYTES],
        [u8; NUM_BYTES],
        [u8; NUM_BYTES],
        [u8; NUM_BYTES],
    ) -> Result<([u8; NUM_BYTES], [u8; NUM_BYTES]), ExitCode>;

    /// Returns a pseudo-random value in `[0, modulus)` built from `NUM_BYTES` random bytes.
    fn random_bigint<const NUM_BYTES: usize>(modulus: &BigUint) -> BigUint {
        let mut rng = rand::rng();
        let mut arr = vec![0u8; NUM_BYTES];
        for item in arr.iter_mut() {
            *item = rng.random();
        }
        BigUint::from_bytes_le(&arr) % modulus
    }

    /// Encodes `a` as exactly `NUM_BYTES` little-endian bytes.
    fn big_uint_into_bytes<const NUM_BYTES: usize>(a: &BigUint) -> [u8; NUM_BYTES] {
        let mut res = a.to_bytes_le();
        res.resize(NUM_BYTES, 0);
        res.try_into().expect("length checked")
    }

    /// Applies `op` to `a = (a0, a1)` and `b = (b0, b1)`, unwrapping and decoding the result.
    fn run<const NUM_BYTES: usize>(
        op: Fp2Op<NUM_BYTES>,
        a: (&BigUint, &BigUint),
        b: (&BigUint, &BigUint),
    ) -> (BigUint, BigUint) {
        let (res0, res1) = op(
            big_uint_into_bytes(a.0),
            big_uint_into_bytes(a.1),
            big_uint_into_bytes(b.0),
            big_uint_into_bytes(b.1),
        )
        .unwrap();
        (BigUint::from_bytes_le(&res0), BigUint::from_bytes_le(&res1))
    }

    /// Checks add/sub/mul against independently computed expected values.
    ///
    /// Results are compared exactly, without reducing them again, so this also verifies
    /// that every returned coefficient is canonical.
    fn check_fp2_ops<const NUM_BYTES: usize>(
        modulus: &BigUint,
        add: Fp2Op<NUM_BYTES>,
        sub: Fp2Op<NUM_BYTES>,
        mul: Fp2Op<NUM_BYTES>,
    ) {
        let zero = BigUint::ZERO;
        let one = BigUint::from(1u32);
        let minus_one = modulus - &one;

        // Fixed edge cases.
        assert_eq!(run(add, (&zero, &zero), (&zero, &zero)), (zero.clone(), zero.clone()));
        // (p - 1) + 1 wraps to 0.
        assert_eq!(
            run(add, (&minus_one, &minus_one), (&one, &one)),
            (zero.clone(), zero.clone())
        );
        // 0 - 1 wraps to p - 1.
        assert_eq!(
            run(sub, (&zero, &zero), (&one, &one)),
            (minus_one.clone(), minus_one.clone())
        );
        // u * u = -1.
        assert_eq!(run(mul, (&zero, &one), (&zero, &one)), (minus_one.clone(), zero.clone()));

        // Random canonical operands.
        let modulus_squared = modulus * modulus;
        for _ in 0..10 {
            let ac0 = random_bigint::<NUM_BYTES>(modulus);
            let ac1 = random_bigint::<NUM_BYTES>(modulus);
            let bc0 = random_bigint::<NUM_BYTES>(modulus);
            let bc1 = random_bigint::<NUM_BYTES>(modulus);
            let (a, b) = ((&ac0, &ac1), (&bc0, &bc1));

            let expected_add = ((&ac0 + &bc0) % modulus, (&ac1 + &bc1) % modulus);
            assert_eq!(run(add, a, b), expected_add);

            let expected_sub =
                ((&ac0 + modulus - &bc0) % modulus, (&ac1 + modulus - &bc1) % modulus);
            assert_eq!(run(sub, a, b), expected_sub);

            // (a0 + a1 u)(b0 + b1 u) = (a0 b0 - a1 b1) + (a0 b1 + a1 b0) u. Adding
            // modulus^2 keeps the real part non-negative before the single reduction.
            let expected_mul = (
                (&ac0 * &bc0 + &modulus_squared - &ac1 * &bc1) % modulus,
                (&ac0 * &bc1 + &ac1 * &bc0) % modulus,
            );
            assert_eq!(run(mul, a, b), expected_mul);
        }
    }

    #[test]
    fn test_tower_fp2_bn254_add_sub_mul() {
        let modulus = BigUint::from_str(BN254_MODULUS).unwrap();
        check_fp2_ops::<BN254_FP_SIZE>(
            &modulus,
            syscall_tower_fp2_bn254_add_impl,
            syscall_tower_fp2_bn254_sub_impl,
            syscall_tower_fp2_bn254_mul_impl,
        );
    }

    #[test]
    fn test_tower_fp2_bls12381_add_sub_mul() {
        let modulus = BigUint::from_str(BLS12381_MODULUS).unwrap();
        check_fp2_ops::<BLS12381_FP_SIZE>(
            &modulus,
            syscall_tower_fp2_bls12381_add_impl,
            syscall_tower_fp2_bls12381_sub_impl,
            syscall_tower_fp2_bls12381_mul_impl,
        );
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: MalformedBuiltinParams"
    )]
    fn test_tower_fp2_bn254_sub_panics_on_non_canonical_bc1() {
        let modulus = BigUint::from_str(BN254_MODULUS).unwrap();
        // `0 + p < p + 1`, so the shifted subtraction would underflow and is rejected.
        let non_canonical = &modulus + BigUint::from(1u32);

        let _ = syscall_tower_fp2_bn254_sub_impl(
            big_uint_into_bytes(&BigUint::ZERO),
            big_uint_into_bytes(&BigUint::ZERO),
            big_uint_into_bytes(&BigUint::ZERO),
            big_uint_into_bytes(&non_canonical),
        )
        .unwrap();
    }

    #[test]
    fn test_tower_fp2_bn254_sub_accepts_non_canonical_without_underflow() {
        let modulus = BigUint::from_str(BN254_MODULUS).unwrap();
        let one = BigUint::from(1u32);
        let non_canonical = &modulus + &one;

        // `1 + p - (p + 1) = 0`: no underflow, so the input is accepted and reduced.
        let (res0, res1) = syscall_tower_fp2_bn254_sub_impl(
            big_uint_into_bytes(&one),
            big_uint_into_bytes(&one),
            big_uint_into_bytes(&non_canonical),
            big_uint_into_bytes(&BigUint::ZERO),
        )
        .unwrap();
        assert_eq!(BigUint::from_bytes_le(&res0), BigUint::ZERO);
        assert_eq!(BigUint::from_bytes_le(&res1), one);
    }
}