1use crate::{syscall_handler::syscall_process_exit_code, RuntimeContext};
2use fluentbase_types::{ExitCode, BLS12381_FP_SIZE, BN254_FP_SIZE};
3use num::BigUint;
4use rwasm::{StoreTr, TrapCode, Value};
5use sp1_curves::weierstrass::{bls12_381::Bls12381BaseField, bn254::Bn254BaseField, FpOpField};
6
7pub fn syscall_tower_fp2_bn254_add_handler(
8 ctx: &mut impl StoreTr<RuntimeContext>,
9 params: &[Value],
10 _result: &mut [Value],
11) -> Result<(), TrapCode> {
12 syscall_tower_fp2_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(
13 ctx, params, _result,
14 )
15}
16pub fn syscall_tower_fp2_bn254_sub_handler(
17 ctx: &mut impl StoreTr<RuntimeContext>,
18 params: &[Value],
19 _result: &mut [Value],
20) -> Result<(), TrapCode> {
21 syscall_tower_fp2_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(
22 ctx, params, _result,
23 )
24}
25pub fn syscall_tower_fp2_bn254_mul_handler(
26 ctx: &mut impl StoreTr<RuntimeContext>,
27 params: &[Value],
28 _result: &mut [Value],
29) -> Result<(), TrapCode> {
30 syscall_tower_fp2_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(
31 ctx, params, _result,
32 )
33}
34pub fn syscall_tower_fp2_bls12381_add_handler(
35 ctx: &mut impl StoreTr<RuntimeContext>,
36 params: &[Value],
37 _result: &mut [Value],
38) -> Result<(), TrapCode> {
39 syscall_tower_fp2_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(
40 ctx, params, _result,
41 )
42}
43pub fn syscall_tower_fp2_bls12381_sub_handler(
44 ctx: &mut impl StoreTr<RuntimeContext>,
45 params: &[Value],
46 _result: &mut [Value],
47) -> Result<(), TrapCode> {
48 syscall_tower_fp2_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(
49 ctx, params, _result,
50 )
51}
52pub fn syscall_tower_fp2_bls12381_mul_handler(
53 ctx: &mut impl StoreTr<RuntimeContext>,
54 params: &[Value],
55 _result: &mut [Value],
56) -> Result<(), TrapCode> {
57 syscall_tower_fp2_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(
58 ctx, params, _result,
59 )
60}
61
/// `FIELD_OP` selector for Fp2 addition in the generic tower handler/impl.
const FP_FIELD_ADD: u32 = 1;
/// `FIELD_OP` selector for Fp2 subtraction in the generic tower handler/impl.
const FP_FIELD_SUB: u32 = 2;
/// `FIELD_OP` selector for Fp2 multiplication in the generic tower handler/impl.
const FP_FIELD_MUL: u32 = 3;
65
66pub(crate) fn syscall_tower_fp2_add_sub_mul_handler<
67 const NUM_BYTES: usize,
68 P: FpOpField,
69 const FIELD_OP: u32,
70>(
71 ctx: &mut impl StoreTr<RuntimeContext>,
72 params: &[Value],
73 _result: &mut [Value],
74) -> Result<(), TrapCode> {
75 let (x_ptr, y_ptr) = (
76 params[0].i32().unwrap() as u32,
77 params[1].i32().unwrap() as u32,
78 );
79 let mut ac0 = [0u8; NUM_BYTES];
80 let mut ac1 = [0u8; NUM_BYTES];
81 ctx.memory_read(x_ptr as usize, &mut ac0)?;
82 ctx.memory_read(x_ptr as usize + NUM_BYTES, &mut ac1)?;
83 let mut bc0 = [0u8; NUM_BYTES];
84 let mut bc1 = [0u8; NUM_BYTES];
85 ctx.memory_read(y_ptr as usize, &mut bc0)?;
86 ctx.memory_read(y_ptr as usize + NUM_BYTES, &mut bc1)?;
87
88 let (res0, res1) = syscall_tower_fp2_add_sub_mul_impl::<NUM_BYTES, P, FIELD_OP>(
89 ac0,
90 ac1,
91 bc0,
92 bc1,
93 )
94 .map_err(|exit_code| syscall_process_exit_code(ctx, exit_code))?;
95
96 ctx.memory_write(x_ptr as usize, &res0)?;
97 ctx.memory_write(x_ptr as usize + NUM_BYTES, &res1)?;
98 Ok(())
99}
100
101pub fn syscall_tower_fp2_bn254_add_impl(
102 ac0: [u8; BN254_FP_SIZE],
103 ac1: [u8; BN254_FP_SIZE],
104 bc0: [u8; BN254_FP_SIZE],
105 bc1: [u8; BN254_FP_SIZE],
106) -> Result<([u8; BN254_FP_SIZE], [u8; BN254_FP_SIZE]), ExitCode> {
107 syscall_tower_fp2_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(
108 ac0, ac1, bc0, bc1,
109 )
110}
111pub fn syscall_tower_fp2_bn254_sub_impl(
112 ac0: [u8; BN254_FP_SIZE],
113 ac1: [u8; BN254_FP_SIZE],
114 bc0: [u8; BN254_FP_SIZE],
115 bc1: [u8; BN254_FP_SIZE],
116) -> Result<([u8; BN254_FP_SIZE], [u8; BN254_FP_SIZE]), ExitCode> {
117 syscall_tower_fp2_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(
118 ac0, ac1, bc0, bc1,
119 )
120}
121pub fn syscall_tower_fp2_bn254_mul_impl(
122 ac0: [u8; BN254_FP_SIZE],
123 ac1: [u8; BN254_FP_SIZE],
124 bc0: [u8; BN254_FP_SIZE],
125 bc1: [u8; BN254_FP_SIZE],
126) -> Result<([u8; BN254_FP_SIZE], [u8; BN254_FP_SIZE]), ExitCode> {
127 syscall_tower_fp2_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(
128 ac0, ac1, bc0, bc1,
129 )
130}
131pub fn syscall_tower_fp2_bls12381_add_impl(
132 ac0: [u8; BLS12381_FP_SIZE],
133 ac1: [u8; BLS12381_FP_SIZE],
134 bc0: [u8; BLS12381_FP_SIZE],
135 bc1: [u8; BLS12381_FP_SIZE],
136) -> Result<([u8; BLS12381_FP_SIZE], [u8; BLS12381_FP_SIZE]), ExitCode> {
137 syscall_tower_fp2_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(
138 ac0, ac1, bc0, bc1,
139 )
140}
141pub fn syscall_tower_fp2_bls12381_sub_impl(
142 ac0: [u8; BLS12381_FP_SIZE],
143 ac1: [u8; BLS12381_FP_SIZE],
144 bc0: [u8; BLS12381_FP_SIZE],
145 bc1: [u8; BLS12381_FP_SIZE],
146) -> Result<([u8; BLS12381_FP_SIZE], [u8; BLS12381_FP_SIZE]), ExitCode> {
147 syscall_tower_fp2_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(
148 ac0, ac1, bc0, bc1,
149 )
150}
151pub fn syscall_tower_fp2_bls12381_mul_impl(
152 ac0: [u8; BLS12381_FP_SIZE],
153 ac1: [u8; BLS12381_FP_SIZE],
154 bc0: [u8; BLS12381_FP_SIZE],
155 bc1: [u8; BLS12381_FP_SIZE],
156) -> Result<([u8; BLS12381_FP_SIZE], [u8; BLS12381_FP_SIZE]), ExitCode> {
157 syscall_tower_fp2_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(
158 ac0, ac1, bc0, bc1,
159 )
160}
161
162pub(crate) fn syscall_tower_fp2_add_sub_mul_impl<
163 const NUM_BYTES: usize,
164 P: FpOpField,
165 const FIELD_OP: u32,
166>(
167 ac0: [u8; NUM_BYTES],
168 ac1: [u8; NUM_BYTES],
169 bc0: [u8; NUM_BYTES],
170 bc1: [u8; NUM_BYTES],
171) -> Result<([u8; NUM_BYTES], [u8; NUM_BYTES]), ExitCode> {
172 let ac0 = &BigUint::from_bytes_le(&ac0);
173 let ac1 = &BigUint::from_bytes_le(&ac1);
174 let bc0 = &BigUint::from_bytes_le(&bc0);
175 let bc1 = &BigUint::from_bytes_le(&bc1);
176 let modulus = &BigUint::from_bytes_le(P::MODULUS);
177 let (c0, c1) = match FIELD_OP {
178 FP_FIELD_ADD => ((ac0 + bc0) % modulus, (ac1 + bc1) % modulus),
179 FP_FIELD_SUB => {
180 if ac0 + modulus < *bc0 || ac1 + modulus < *bc1 {
181 return Err(ExitCode::MalformedBuiltinParams);
182 }
183 (
184 (ac0 + modulus - bc0) % modulus,
185 (ac1 + modulus - bc1) % modulus,
186 )
187 }
188 FP_FIELD_MUL => {
189 let c0 = match (ac0 * bc0) % modulus < (ac1 * bc1) % modulus {
190 true => ((modulus + (ac0 * bc0) % modulus) - (ac1 * bc1) % modulus) % modulus,
191 false => ((ac0 * bc0) % modulus - (ac1 * bc1) % modulus) % modulus,
192 };
193 let c1 = ((ac0 * bc1) % modulus + (ac1 * bc0) % modulus) % modulus;
194 (c0, c1)
195 }
196 _ => unreachable!(),
197 };
198 let mut res0 = c0.to_bytes_le();
199 res0.resize(NUM_BYTES, 0);
200 let mut res1 = c1.to_bytes_le();
201 res1.resize(NUM_BYTES, 0);
202 let result: ([u8; NUM_BYTES], [u8; NUM_BYTES]) = (
203 res0.try_into().expect("length checked"),
204 res1.try_into().expect("length checked"),
205 );
206 Ok(result)
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use std::str::FromStr;

    /// Signature shared by every `syscall_tower_fp2_*_impl` entry point.
    type Fp2Op<const N: usize> =
        fn([u8; N], [u8; N], [u8; N], [u8; N]) -> Result<([u8; N], [u8; N]), ExitCode>;

    const BN254_MODULUS: &str =
        "21888242871839275222246405745257275088696311157297823662689037894645226208583";
    const BLS12381_MODULUS: &str =
        "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787";

    /// Returns a random value of `NUM_BYTES` random bytes, reduced modulo `modulus`.
    fn random_bigint<const NUM_BYTES: usize>(modulus: &BigUint) -> BigUint {
        let mut rng = rand::rng();
        let mut arr = vec![0u8; NUM_BYTES];
        for item in arr.iter_mut() {
            *item = rng.random();
        }
        BigUint::from_bytes_le(&arr) % modulus
    }

    /// Encodes `a` as `NUM_BYTES` little-endian bytes (zero-padded).
    fn big_uint_into_bytes<const NUM_BYTES: usize>(a: &BigUint) -> [u8; NUM_BYTES] {
        let mut res = a.to_bytes_le();
        res.resize(NUM_BYTES, 0);
        res.try_into().expect("length checked")
    }

    /// Runs `op` on `(ac0 + ac1·u, bc0 + bc1·u)` and decodes the raw result coordinates.
    ///
    /// Panics if `op` returns an error.
    fn eval_fp2<const N: usize>(
        op: Fp2Op<N>,
        ac0: &BigUint,
        ac1: &BigUint,
        bc0: &BigUint,
        bc1: &BigUint,
    ) -> (BigUint, BigUint) {
        let (res0, res1) = op(
            big_uint_into_bytes(ac0),
            big_uint_into_bytes(ac1),
            big_uint_into_bytes(bc0),
            big_uint_into_bytes(bc1),
        )
        .unwrap();
        (BigUint::from_bytes_le(&res0), BigUint::from_bytes_le(&res1))
    }

    /// Checks add/sub/mul for one field against known answers and a BigUint reference.
    ///
    /// Outputs are compared as returned, without reducing them again, so an unreduced
    /// (non-canonical) result is reported as a failure.
    fn check_fp2_ops<const N: usize>(
        modulus: &BigUint,
        add: Fp2Op<N>,
        sub: Fp2Op<N>,
        mul: Fp2Op<N>,
    ) {
        let zero = BigUint::ZERO;
        let one = BigUint::from(1u32);
        let p_minus_one = modulus - &one;

        // Known answers that do not depend on the reference formulas below.
        assert_eq!(
            eval_fp2(add, &zero, &zero, &zero, &zero),
            (zero.clone(), zero.clone())
        );
        // 0 - 1 wraps to p - 1 in both coordinates.
        assert_eq!(
            eval_fp2(sub, &zero, &zero, &one, &one),
            (p_minus_one.clone(), p_minus_one.clone())
        );
        // u · u = -1 = (p - 1) + 0·u.
        assert_eq!(
            eval_fp2(mul, &zero, &one, &zero, &one),
            (p_minus_one.clone(), zero.clone())
        );

        for _ in 0..10 {
            let ac0 = random_bigint::<N>(modulus);
            let ac1 = random_bigint::<N>(modulus);
            let bc0 = random_bigint::<N>(modulus);
            let bc1 = random_bigint::<N>(modulus);

            let expected_add = ((&ac0 + &bc0) % modulus, (&ac1 + &bc1) % modulus);
            assert_eq!(eval_fp2(add, &ac0, &ac1, &bc0, &bc1), expected_add);

            let expected_sub = (
                (&ac0 + modulus - &bc0) % modulus,
                (&ac1 + modulus - &bc1) % modulus,
            );
            assert_eq!(eval_fp2(sub, &ac0, &ac1, &bc0, &bc1), expected_sub);

            let a0b0 = (&ac0 * &bc0) % modulus;
            let a1b1 = (&ac1 * &bc1) % modulus;
            let expected_mul = (
                (a0b0 + modulus - a1b1) % modulus,
                ((&ac0 * &bc1) + (&ac1 * &bc0)) % modulus,
            );
            assert_eq!(eval_fp2(mul, &ac0, &ac1, &bc0, &bc1), expected_mul);
        }
    }

    #[test]
    fn test_tower_fp2_bn254_add_sub_mul() {
        let modulus = BigUint::from_str(BN254_MODULUS).unwrap();
        check_fp2_ops::<BN254_FP_SIZE>(
            &modulus,
            syscall_tower_fp2_bn254_add_impl,
            syscall_tower_fp2_bn254_sub_impl,
            syscall_tower_fp2_bn254_mul_impl,
        );
    }

    #[test]
    fn test_tower_fp2_bls12381_add_sub_mul() {
        let modulus = BigUint::from_str(BLS12381_MODULUS).unwrap();
        check_fp2_ops::<BLS12381_FP_SIZE>(
            &modulus,
            syscall_tower_fp2_bls12381_add_impl,
            syscall_tower_fp2_bls12381_sub_impl,
            syscall_tower_fp2_bls12381_mul_impl,
        );
    }

    #[test]
    #[should_panic(
        expected = "called `Result::unwrap()` on an `Err` value: MalformedBuiltinParams"
    )]
    fn test_tower_fp2_bn254_sub_panics_on_non_canonical_bc1() {
        let modulus = BigUint::from_str(BN254_MODULUS).unwrap();
        // 0 + p < p + 1, so the subtraction must be rejected.
        let non_canonical = modulus.clone() + BigUint::from(1u32);

        let _ = syscall_tower_fp2_bn254_sub_impl(
            big_uint_into_bytes(&BigUint::ZERO),
            big_uint_into_bytes(&BigUint::ZERO),
            big_uint_into_bytes(&BigUint::ZERO),
            big_uint_into_bytes(&non_canonical),
        )
        .unwrap();
    }
}