fluentbase_runtime/syscall_handler/host/
write_fd.rs1use crate::{syscall_handler::syscall_process_exit_code, RuntimeContext};
13use fluentbase_types::{
14 fd::{
15 FD_BLS12_381_INVERSE, FD_BLS12_381_SQRT, FD_ECRECOVER_HOOK, FD_ED_DECOMPRESS, FD_FP_INV,
16 FD_FP_SQRT, FD_RSA_MUL_MOD,
17 },
18 ExitCode,
19};
20use rwasm::{StoreTr, TrapCode, Value};
21use sp1_curves::{
22 edwards::ed25519::{ed25519_sqrt, Ed25519BaseField},
23 params::FieldParameters,
24 BigUint, Integer, One,
25};
26
27pub fn syscall_write_fd_handler(
28 caller: &mut impl StoreTr<RuntimeContext>,
29 params: &[Value],
30 _result: &mut [Value],
31) -> Result<(), TrapCode> {
32 let (fd, slice_ptr, slice_len) = (
33 params[0].i32().unwrap() as u32,
34 params[1].i32().unwrap() as u32,
35 params[2].i32().unwrap() as u32,
36 );
37 let mut input = vec![0u8; slice_len as usize];
38 caller.memory_read(slice_ptr as usize, &mut input)?;
39 syscall_write_fd_impl(caller.data_mut(), fd, &input)
40 .map_err(|err| syscall_process_exit_code(caller, err))?;
41 Ok(())
42}
43
44pub fn syscall_write_fd_impl(
45 ctx: &mut RuntimeContext,
46 fd: u32,
47 input: &[u8],
48) -> Result<(), ExitCode> {
49 let output = match fd {
50 FD_ECRECOVER_HOOK => hook_ecrecover(input),
51 FD_ED_DECOMPRESS => hook_ed_decompress(input),
52 FD_RSA_MUL_MOD => hook_rsa_mul_mod(input),
53 FD_BLS12_381_SQRT => bls::hook_bls12_381_sqrt(input),
54 FD_BLS12_381_INVERSE => bls::hook_bls12_381_inverse(input),
55 FD_FP_SQRT => fp_ops::hook_fp_sqrt(input),
56 FD_FP_INV => fp_ops::hook_fp_inverse(input),
57 _ => return Ok(()),
58 }?;
59 ctx.execution_result.return_data = output;
60 Ok(())
61}
62
63fn hook_ecrecover(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
75 if buf.len() != 65 {
76 return Err(ExitCode::MalformedBuiltinParams);
77 }
78
79 let curve_id = buf[0] & 0b0111_1111;
80 let r_is_y_odd = buf[0] & 0b1000_0000 != 0;
81
82 let r_bytes: [u8; 32] = buf[1..33].try_into().unwrap();
83 let alpha_bytes: [u8; 32] = buf[33..65].try_into().unwrap();
84
85 Ok(match curve_id {
86 1 => ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd),
87 2 => ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd),
88 _ => return Err(ExitCode::MalformedBuiltinParams),
89 })
90}
91
92mod ecrecover {
93 use sp1_curves::{k256, p256};
94
95 const NQR: [u8; 32] = {
97 let mut nqr = [0; 32];
98 nqr[31] = 3;
99 nqr
100 };
101
102 pub(super) fn handle_secp256k1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<u8> {
103 use k256::{
104 elliptic_curve::ff::PrimeField, FieldElement as K256FieldElement, Scalar as K256Scalar,
105 };
106
107 let r = K256FieldElement::from_bytes(r.as_ref().into()).unwrap();
108 debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
109
110 let alpha = K256FieldElement::from_bytes(alpha.as_ref().into()).unwrap();
111 assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
112
113 if let Some(mut y_coord) = alpha.sqrt().into_option().map(|y| y.normalize()) {
115 let r = K256Scalar::from_repr(r.to_bytes()).unwrap();
116 let r_inv = r.invert().expect("Non zero r scalar");
117
118 if r_y_is_odd != bool::from(y_coord.is_odd()) {
119 y_coord = y_coord.negate(1);
120 y_coord = y_coord.normalize();
121 }
122
123 let mut result = vec![0x1];
124 result.extend_from_slice(&y_coord.to_bytes());
125 result.extend_from_slice(&r_inv.to_bytes());
126 result
127 } else {
128 let nqr_field = K256FieldElement::from_bytes(NQR.as_ref().into()).unwrap();
129 let qr = alpha * nqr_field;
130 let root = qr
131 .sqrt()
132 .expect("if alpha is not a square, then qr should be a square");
133 let mut result = vec![0x0];
134 result.extend_from_slice(&root.to_bytes());
135 result
136 }
137 }
138
139 pub(super) fn handle_secp256r1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<u8> {
140 use p256::{
141 elliptic_curve::ff::PrimeField, FieldElement as P256FieldElement, Scalar as P256Scalar,
142 };
143
144 let r = P256FieldElement::from_bytes(r.as_ref().into()).unwrap();
145 debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
146
147 let alpha = P256FieldElement::from_bytes(alpha.as_ref().into()).unwrap();
148 debug_assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
149
150 if let Some(mut y_coord) = alpha.sqrt().into_option() {
151 let r = P256Scalar::from_repr(r.to_bytes()).unwrap();
152 let r_inv = r.invert().expect("Non zero r scalar");
153
154 if r_y_is_odd != bool::from(y_coord.is_odd()) {
155 y_coord = -y_coord;
156 }
157
158 let mut result = vec![0x1];
159 result.extend_from_slice(&y_coord.to_bytes());
160 result.extend_from_slice(&r_inv.to_bytes());
161 result
162 } else {
163 let nqr_field = P256FieldElement::from_bytes(NQR.as_ref().into()).unwrap();
164 let qr = alpha * nqr_field;
165 let root = qr
166 .sqrt()
167 .expect("if alpha is not a square, then qr should be a square");
168 let mut result = vec![0x0];
169 result.extend_from_slice(&root.to_bytes());
170 result
171 }
172 }
173}
174
175pub fn hook_ed_decompress(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
189 const NQR_CURVE_25519: u8 = 2;
190 let modulus = Ed25519BaseField::modulus();
191
192 let mut bytes: [u8; 32] = buf[..32].try_into().unwrap();
193 bytes[31] &= 0b0111_1111;
195
196 let y = BigUint::from_bytes_le(&bytes);
198 if y >= modulus {
199 return Ok(vec![0u8]);
200 }
201
202 let v = BigUint::from_bytes_le(&buf[32..]);
203 if v >= modulus {
206 return Err(ExitCode::MalformedBuiltinParams);
207 }
208
209 let v_inv = v.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
211 let u = (&y * &y + &modulus - BigUint::one()) % &modulus;
212 let u_div_v = (&u * &v_inv) % &modulus;
213
214 if ed25519_sqrt(&u_div_v).is_some() {
217 return Ok(vec![0x1]);
218 }
219 let qr = (u_div_v * NQR_CURVE_25519) % &modulus;
220 let root = ed25519_sqrt(&qr).unwrap();
221
222 let v_inv_bytes = v_inv.to_bytes_le();
224 let mut v_inv_padded = [0_u8; 32];
225 v_inv_padded[..v_inv_bytes.len()].copy_from_slice(&v_inv.to_bytes_le());
226
227 let root_bytes = root.to_bytes_le();
228 let mut root_padded = [0_u8; 32];
229 root_padded[..root_bytes.len()].copy_from_slice(&root.to_bytes_le());
230
231 let mut result = vec![0x0];
232 result.extend_from_slice(&v_inv_padded);
233 result.extend_from_slice(&root_padded);
234 Ok(result)
235}
236
237pub fn hook_rsa_mul_mod(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
251 if buf.len() != 256 + 256 + 256 {
252 return Err(ExitCode::MalformedBuiltinParams);
253 }
254
255 let prod: &[u8; 512] = buf[..512].try_into().unwrap();
256 let m: &[u8; 256] = buf[512..].try_into().unwrap();
257
258 let prod = BigUint::from_bytes_le(prod);
259 let m = BigUint::from_bytes_le(m);
260
261 let (q, rem) = prod.div_rem(&m);
262
263 let mut rem = rem.to_bytes_le();
264 rem.resize(256, 0);
265
266 let mut q = q.to_bytes_le();
267 q.resize(256, 0);
268
269 let mut result = rem;
270 result.extend_from_slice(&q);
271 Ok(result)
272}
273
274mod bls {
275 use super::{pad_to_be, BigUint};
276 use fluentbase_types::ExitCode;
277 use sp1_curves::{params::FieldParameters, weierstrass::bls12_381::Bls12381BaseField, Zero};
278
279 pub const NQR_BLS12_381: [u8; 48] = {
281 let mut nqr = [0; 48];
282 nqr[47] = 2;
283 nqr
284 };
285
286 pub const BLS12_381_MODULUS: &[u8] = Bls12381BaseField::MODULUS;
288
289 pub fn hook_bls12_381_sqrt(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
298 let field_element = BigUint::from_bytes_be(&buf[..48]);
299
300 if field_element.is_zero() {
303 let mut result = vec![1];
304 result.resize(48 + 1, 0);
305 return Ok(result);
306 }
307
308 let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
309
310 let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64);
313 let sqrt = field_element.modpow(&exp, &modulus);
314
315 let square = (&sqrt * &sqrt) % &modulus;
318 if square != field_element {
319 let nqr = BigUint::from_bytes_be(&NQR_BLS12_381);
320 let qr = (&nqr * &field_element) % &modulus;
321
322 let root = qr.modpow(&exp, &modulus);
327
328 debug_assert!(
329 (&root * &root) % &modulus == qr,
330 "NQR sanity check failed, this is a bug."
331 );
332
333 let mut result = vec![0];
334 result.extend(pad_to_be(&root, 48));
335 return Ok(result);
336 }
337
338 let mut result = vec![1];
339 result.extend(pad_to_be(&sqrt, 48));
340 Ok(result)
341 }
342
343 pub fn hook_bls12_381_inverse(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
347 let field_element = BigUint::from_bytes_be(&buf[..48]);
348
349 if field_element.is_zero() {
351 return Err(ExitCode::MalformedBuiltinParams);
352 }
353
354 let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
355
356 let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
358
359 Ok(pad_to_be(&inverse, 48))
360 }
361}
362
363fn pad_to_be(val: &BigUint, len: usize) -> Vec<u8> {
365 let mut bytes = val.to_bytes_le();
367 if len > bytes.len() {
369 bytes.resize(len, 0);
370 }
371 bytes.reverse();
373
374 bytes
375}
376
377mod fp_ops {
378 use super::{pad_to_be, BigUint, One};
379 use fluentbase_types::ExitCode;
380 use sp1_curves::Zero;
381
382 pub fn hook_fp_inverse(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
398 let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
399
400 if buf.len() != 4 + 2 * len {
401 return Err(ExitCode::MalformedBuiltinParams);
402 }
403
404 let buf = &buf[4..];
405 let element = BigUint::from_bytes_be(&buf[..len]);
406 let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
407
408 if element.is_zero() {
409 return Err(ExitCode::MalformedBuiltinParams);
410 }
411
412 let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
413
414 Ok(pad_to_be(&inverse, len))
415 }
416
417 pub fn hook_fp_sqrt(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
445 let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
446
447 if buf.len() != 4 + 3 * len {
448 return Err(ExitCode::MalformedBuiltinParams);
449 }
450
451 let buf = &buf[4..];
452 let element = BigUint::from_bytes_be(&buf[..len]);
453 let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
454 let nqr = BigUint::from_bytes_be(&buf[2 * len..3 * len]);
455
456 if element > modulus || nqr > modulus {
457 return Err(ExitCode::MalformedBuiltinParams);
458 }
459
460 if element.is_zero() {
462 let mut result = vec![1];
463 result.resize(len + 1, 0);
464 return Ok(result);
465 }
466
467 if let Some(root) = sqrt_fp(&element, &modulus, &nqr) {
470 let mut result = vec![1];
471 result.extend(pad_to_be(&root, len));
472 Ok(result)
473 } else {
474 let qr = (&nqr * &element) % &modulus;
475 let root = sqrt_fp(&qr, &modulus, &nqr).unwrap();
476 let mut result = vec![0];
477 result.extend(pad_to_be(&root, len));
478 Ok(result)
479 }
480 }
481
482 fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
486 if modulus % BigUint::from(4u64) == BigUint::from(3u64) {
489 let maybe_root = element.modpow(
490 &((modulus + BigUint::from(1u64)) / BigUint::from(4u64)),
491 modulus,
492 );
493
494 return Some(maybe_root).filter(|root| root * root % modulus == *element);
495 }
496
497 tonelli_shanks(element, modulus, nqr)
498 }
499
500 #[allow(clippy::many_single_char_names)]
512 fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
513 if legendre_symbol(element, modulus) != BigUint::one() {
516 return None;
517 }
518
519 let mut s = BigUint::zero();
521 let mut q = modulus - BigUint::one();
522 while &q % &BigUint::from(2u64) == BigUint::zero() {
523 s += BigUint::from(1u64);
524 q /= BigUint::from(2u64);
525 }
526
527 let z = nqr;
528 let mut c = z.modpow(&q, modulus);
529 let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus);
530 let mut t = element.modpow(&q, modulus);
531 let mut m = s;
532
533 while t != BigUint::one() {
534 let mut i = BigUint::zero();
535 let mut tt = t.clone();
536 while tt != BigUint::one() {
537 tt = &tt * &tt % modulus;
538 i += BigUint::from(1u64);
539
540 if i == m {
541 return None;
542 }
543 }
544
545 let b_pow =
546 BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap());
547 let b = c.modpow(&b_pow, modulus);
548
549 r = &r * &b % modulus;
550 c = &b * &b % modulus;
551 t = &t * &c % modulus;
552 m = i;
553 }
554
555 Some(r)
556 }
557
558 fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint {
564 assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called.");
565
566 element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus)
567 }
568
569 #[cfg(test)]
570 mod test {
571 use super::*;
572 use std::str::FromStr;
573
574 #[test]
575 fn test_legendre_symbol() {
576 let modulus = BigUint::from_str(
578 "115792089237316195423570985008687907853269984665640564039457584007908834671663",
579 )
580 .unwrap();
581 let neg_1 = &modulus - BigUint::one();
582
583 let fixtures = [
584 (BigUint::from(4u64), BigUint::from(1u64)),
585 (BigUint::from(2u64), BigUint::from(1u64)),
586 (BigUint::from(3u64), neg_1.clone()),
587 ];
588
589 for (element, expected) in fixtures {
590 let result = legendre_symbol(&element, &modulus);
591 assert_eq!(result, expected);
592 }
593 }
594
595 #[test]
596 fn test_tonelli_shanks() {
597 let p = BigUint::from_str(
599 "115792089237316195423570985008687907853269984665640564039457584007908834671663",
600 )
601 .unwrap();
602
603 let nqr = BigUint::from_str("3").unwrap();
604
605 let large_element = &p - BigUint::from(u16::MAX);
606 let square = &large_element * &large_element % &p;
607
608 let fixtures = [
609 (BigUint::from(2u64), true),
610 (BigUint::from(3u64), false),
611 (BigUint::from(4u64), true),
612 (square, true),
613 ];
614
615 for (element, expected) in fixtures {
616 let result = tonelli_shanks(&element, &p, &nqr);
617 if expected {
618 assert!(result.is_some());
619
620 let result = result.unwrap();
621 assert!((&result * &result) % &p == element);
622 } else {
623 assert!(result.is_none());
624 }
625 }
626 }
627 }
628}