1use crate::{syscall_handler::syscall_process_exit_code, RuntimeContext};
2use fluentbase_types::{ExitCode, BLS12381_FP_SIZE, BN254_FP_SIZE};
3use num::BigUint;
4use rwasm::{StoreTr, TrapCode, Value};
5use sp1_curves::weierstrass::{bls12_381::Bls12381BaseField, bn254::Bn254BaseField, FpOpField};
6
7pub fn syscall_tower_fp1_bn254_add_handler(
8 ctx: &mut impl StoreTr<RuntimeContext>,
9 params: &[Value],
10 _result: &mut [Value],
11) -> Result<(), TrapCode> {
12 syscall_tower_fp1_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(
13 ctx, params, _result,
14 )
15}
16pub fn syscall_tower_fp1_bn254_sub_handler(
17 ctx: &mut impl StoreTr<RuntimeContext>,
18 params: &[Value],
19 _result: &mut [Value],
20) -> Result<(), TrapCode> {
21 syscall_tower_fp1_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(
22 ctx, params, _result,
23 )
24}
25pub fn syscall_tower_fp1_bn254_mul_handler(
26 ctx: &mut impl StoreTr<RuntimeContext>,
27 params: &[Value],
28 _result: &mut [Value],
29) -> Result<(), TrapCode> {
30 syscall_tower_fp1_add_sub_mul_handler::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(
31 ctx, params, _result,
32 )
33}
34pub fn syscall_tower_fp1_bls12381_add_handler(
35 ctx: &mut impl StoreTr<RuntimeContext>,
36 params: &[Value],
37 _result: &mut [Value],
38) -> Result<(), TrapCode> {
39 syscall_tower_fp1_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(
40 ctx, params, _result,
41 )
42}
43pub fn syscall_tower_fp1_bls12381_sub_handler(
44 ctx: &mut impl StoreTr<RuntimeContext>,
45 params: &[Value],
46 _result: &mut [Value],
47) -> Result<(), TrapCode> {
48 syscall_tower_fp1_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(
49 ctx, params, _result,
50 )
51}
52pub fn syscall_tower_fp1_bls12381_mul_handler(
53 ctx: &mut impl StoreTr<RuntimeContext>,
54 params: &[Value],
55 _result: &mut [Value],
56) -> Result<(), TrapCode> {
57 syscall_tower_fp1_add_sub_mul_handler::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(
58 ctx, params, _result,
59 )
60}
61
// Operation selectors passed as the `FIELD_OP` const generic of the generic
// handler/impl below; the impl matches on them to pick the arithmetic to perform.

/// Selects modular addition: `(x + y) % modulus`.
const FP_FIELD_ADD: u32 = 0x01;
/// Selects modular subtraction: `((x + modulus) - y) % modulus`.
const FP_FIELD_SUB: u32 = 0x02;
/// Selects modular multiplication: `(x * y) % modulus`.
const FP_FIELD_MUL: u32 = 0x03;
65
66pub(crate) fn syscall_tower_fp1_add_sub_mul_handler<
67 const NUM_BYTES: usize,
68 P: FpOpField,
69 const FIELD_OP: u32,
70>(
71 ctx: &mut impl StoreTr<RuntimeContext>,
72 params: &[Value],
73 _result: &mut [Value],
74) -> Result<(), TrapCode> {
75 let (x_ptr, y_ptr) = (
76 params[0].i32().unwrap() as u32,
77 params[1].i32().unwrap() as u32,
78 );
79 let mut x = [0u8; NUM_BYTES];
80 ctx.memory_read(x_ptr as usize, &mut x)?;
81 let mut y = [0u8; NUM_BYTES];
82 ctx.memory_read(y_ptr as usize, &mut y)?;
83
84 let result = syscall_tower_fp1_add_sub_mul_impl::<NUM_BYTES, P, FIELD_OP>(
85 x,
86 y,
87 )
88 .map_err(|exit_code| syscall_process_exit_code(ctx, exit_code))?;
89
90 ctx.memory_write(x_ptr as usize, &result)?;
91 Ok(())
92}
93
94pub fn syscall_tower_fp1_bn254_add_impl(
95 x: [u8; BN254_FP_SIZE],
96 y: [u8; BN254_FP_SIZE],
97) -> Result<[u8; BN254_FP_SIZE], ExitCode> {
98 syscall_tower_fp1_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_ADD>(x, y)
99}
100pub fn syscall_tower_fp1_bn254_sub_impl(
101 x: [u8; BN254_FP_SIZE],
102 y: [u8; BN254_FP_SIZE],
103) -> Result<[u8; BN254_FP_SIZE], ExitCode> {
104 syscall_tower_fp1_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_SUB>(x, y)
105}
106pub fn syscall_tower_fp1_bn254_mul_impl(
107 x: [u8; BN254_FP_SIZE],
108 y: [u8; BN254_FP_SIZE],
109) -> Result<[u8; BN254_FP_SIZE], ExitCode> {
110 syscall_tower_fp1_add_sub_mul_impl::<BN254_FP_SIZE, Bn254BaseField, FP_FIELD_MUL>(x, y)
111}
112pub fn syscall_tower_fp1_bls12381_add_impl(
113 x: [u8; BLS12381_FP_SIZE],
114 y: [u8; BLS12381_FP_SIZE],
115) -> Result<[u8; BLS12381_FP_SIZE], ExitCode> {
116 syscall_tower_fp1_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_ADD>(x, y)
117}
118pub fn syscall_tower_fp1_bls12381_sub_impl(
119 x: [u8; BLS12381_FP_SIZE],
120 y: [u8; BLS12381_FP_SIZE],
121) -> Result<[u8; BLS12381_FP_SIZE], ExitCode> {
122 syscall_tower_fp1_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_SUB>(x, y)
123}
124pub fn syscall_tower_fp1_bls12381_mul_impl(
125 x: [u8; BLS12381_FP_SIZE],
126 y: [u8; BLS12381_FP_SIZE],
127) -> Result<[u8; BLS12381_FP_SIZE], ExitCode> {
128 syscall_tower_fp1_add_sub_mul_impl::<BLS12381_FP_SIZE, Bls12381BaseField, FP_FIELD_MUL>(x, y)
129}
130
131pub(crate) fn syscall_tower_fp1_add_sub_mul_impl<
132 const NUM_BYTES: usize,
133 P: FpOpField,
134 const FIELD_OP: u32,
135>(
136 x: [u8; NUM_BYTES],
137 y: [u8; NUM_BYTES],
138) -> Result<[u8; NUM_BYTES], ExitCode> {
139 let modulus = &BigUint::from_bytes_le(P::MODULUS);
140 let a = BigUint::from_bytes_le(&x);
141 let b = BigUint::from_bytes_le(&y);
142 let result = match FIELD_OP {
143 FP_FIELD_ADD => (a + b) % modulus,
144 FP_FIELD_SUB => {
145 if &a + modulus < b {
146 return Err(ExitCode::MalformedBuiltinParams);
147 }
148 ((a + modulus) - b) % modulus
149 }
150 FP_FIELD_MUL => (a * b) % modulus,
151 _ => unreachable!(),
152 };
153 let mut result = result.to_bytes_le();
154 result.resize(NUM_BYTES, 0);
155 let result: [u8; NUM_BYTES] = result.try_into().expect("length checked");
156 Ok(result)
157}
158
#[cfg(test)]
mod tests {
    //! Randomised checks of the base-field add/sub/mul implementations against
    //! plain `BigUint` arithmetic.
    //!
    //! Results are compared *without* reducing them again in the test, so an
    //! implementation that returned a non-canonical (unreduced) value would fail.

    use super::*;
    use rand::Rng;
    use std::str::FromStr;

    /// Decimal modulus of the BN254 base field.
    const BN254_MODULUS: &str =
        "21888242871839275222246405745257275088696311157297823662689037894645226208583";
    /// Decimal modulus of the BLS12-381 base field.
    const BLS12381_MODULUS: &str = "4002409555221667393417789825735904156556882819939007885332058136124031650490837864442687629129015664037894272559787";

    /// Signature shared by all `syscall_tower_fp1_*_impl` functions of one field.
    type FieldOp<const NUM_BYTES: usize> =
        fn([u8; NUM_BYTES], [u8; NUM_BYTES]) -> Result<[u8; NUM_BYTES], ExitCode>;

    /// Returns a uniformly random `NUM_BYTES`-byte integer reduced modulo `modulus`.
    fn random_bigint<const NUM_BYTES: usize>(modulus: &BigUint) -> BigUint {
        let mut rng = rand::rng();
        let bytes: Vec<u8> = (0..NUM_BYTES).map(|_| rng.random::<u8>()).collect();
        BigUint::from_bytes_le(&bytes) % modulus
    }

    /// Encodes `a` little-endian, zero-padded (or truncated) to `NUM_BYTES`.
    fn big_uint_into_bytes<const NUM_BYTES: usize>(a: &BigUint) -> [u8; NUM_BYTES] {
        let mut res = a.to_bytes_le();
        res.resize(NUM_BYTES, 0);
        res.try_into().expect("length checked")
    }

    /// Runs `op` on `a` and `b` and decodes the (expected successful) result.
    fn apply<const NUM_BYTES: usize>(op: FieldOp<NUM_BYTES>, a: &BigUint, b: &BigUint) -> BigUint {
        let out = op(big_uint_into_bytes(a), big_uint_into_bytes(b)).unwrap();
        BigUint::from_bytes_le(&out)
    }

    /// Checks add/sub/mul of one field on random reduced operands and on the
    /// identities `a + 0`, `a - 0`, `a * 1` and `a * 0`.
    fn check_field_ops<const NUM_BYTES: usize>(
        modulus: &BigUint,
        add: FieldOp<NUM_BYTES>,
        sub: FieldOp<NUM_BYTES>,
        mul: FieldOp<NUM_BYTES>,
    ) {
        let zero = BigUint::ZERO;
        let one = BigUint::from(1u32);

        for _ in 0..10 {
            // Both operands are already reduced, so `a` itself is the canonical
            // expected value for the identity checks below.
            let a = random_bigint::<NUM_BYTES>(modulus);
            let b = random_bigint::<NUM_BYTES>(modulus);

            assert_eq!(apply(add, &a, &b), (&a + &b) % modulus);
            assert_eq!(apply(add, &a, &zero), a);

            // `b < modulus`, hence `a + modulus - b` never underflows.
            assert_eq!(apply(sub, &a, &b), (&a + modulus - &b) % modulus);
            assert_eq!(apply(sub, &a, &zero), a);

            assert_eq!(apply(mul, &a, &b), (&a * &b) % modulus);
            assert_eq!(apply(mul, &a, &one), a);
            assert_eq!(apply(mul, &a, &zero), zero);
        }
    }

    #[test]
    fn test_bn254_fp1() {
        let modulus = BigUint::from_str(BN254_MODULUS).unwrap();
        check_field_ops::<BN254_FP_SIZE>(
            &modulus,
            syscall_tower_fp1_bn254_add_impl,
            syscall_tower_fp1_bn254_sub_impl,
            syscall_tower_fp1_bn254_mul_impl,
        );
    }

    #[test]
    fn test_bls12381_fp1() {
        let modulus = BigUint::from_str(BLS12381_MODULUS).unwrap();
        check_field_ops::<BLS12381_FP_SIZE>(
            &modulus,
            syscall_tower_fp1_bls12381_add_impl,
            syscall_tower_fp1_bls12381_sub_impl,
            syscall_tower_fp1_bls12381_mul_impl,
        );
    }

    /// Subtracting an unreduced `b` with `b > a + modulus` must be rejected with
    /// `MalformedBuiltinParams` instead of underflowing.
    #[test]
    fn test_bls12381_overflow_cant_happen() {
        let m = BigUint::from_str(BLS12381_MODULUS).unwrap();
        let a = BigUint::from_str("3001807166416250545063342369301928117417662114954255913999043602093023737868128398332015721846761748028420704419841").unwrap();
        let b = BigUint::from_str("12007228665665002180253369477207712469670648459817023655996174408372094951472513593328062887387046992113682817679361").unwrap();
        assert!(a < m && b > m);
        // Precondition that triggers the rejection path.
        assert!(&a + &m < b);

        let result =
            syscall_tower_fp1_bls12381_sub_impl(big_uint_into_bytes(&a), big_uint_into_bytes(&b));
        assert!(
            matches!(result, Err(ExitCode::MalformedBuiltinParams)),
            "expected MalformedBuiltinParams, got {:?}",
            result
        );
    }
}