Skip to main content

fluentbase_runtime/syscall_handler/host/
write_fd.rs

// This file is adapted from SP1's core executor: sp1/crates/core/executor/src/hook.rs
//
// It differs from the original in the following ways:
// 1. Deprecated SP1 hooks are not supported.
// 2. Hooks return a single linear `Vec<u8>` instead of `Vec<Vec<u8>>`.
// 3. Instead of panicking on malformed input, hooks return `ExitCode::MalformedBuiltinParams`.
//
// Apart from these changes, the hook logic follows the original.
// P.S.: Because of the changes above we can't reuse SP1's crate directly; it also requires a
// `HookEnv` context that we don't have.
//
12use crate::{syscall_handler::syscall_process_exit_code, RuntimeContext};
13use fluentbase_types::{
14    fd::{
15        FD_BLS12_381_INVERSE, FD_BLS12_381_SQRT, FD_ECRECOVER_HOOK, FD_ED_DECOMPRESS, FD_FP_INV,
16        FD_FP_SQRT, FD_RSA_MUL_MOD,
17    },
18    ExitCode,
19};
20use rwasm::{StoreTr, TrapCode, Value};
21use sp1_curves::{
22    edwards::ed25519::{ed25519_sqrt, Ed25519BaseField},
23    params::FieldParameters,
24    BigUint, Integer, One,
25};
26
27pub fn syscall_write_fd_handler(
28    caller: &mut impl StoreTr<RuntimeContext>,
29    params: &[Value],
30    _result: &mut [Value],
31) -> Result<(), TrapCode> {
32    let (fd, slice_ptr, slice_len) = (
33        params[0].i32().unwrap() as u32,
34        params[1].i32().unwrap() as u32,
35        params[2].i32().unwrap() as u32,
36    );
37    let mut input = vec![0u8; slice_len as usize];
38    caller.memory_read(slice_ptr as usize, &mut input)?;
39    syscall_write_fd_impl(caller.data_mut(), fd, &input)
40        .map_err(|err| syscall_process_exit_code(caller, err))?;
41    Ok(())
42}
43
44pub fn syscall_write_fd_impl(
45    ctx: &mut RuntimeContext,
46    fd: u32,
47    input: &[u8],
48) -> Result<(), ExitCode> {
49    let output = match fd {
50        FD_ECRECOVER_HOOK => hook_ecrecover(input),
51        FD_ED_DECOMPRESS => hook_ed_decompress(input),
52        FD_RSA_MUL_MOD => hook_rsa_mul_mod(input),
53        FD_BLS12_381_SQRT => bls::hook_bls12_381_sqrt(input),
54        FD_BLS12_381_INVERSE => bls::hook_bls12_381_inverse(input),
55        FD_FP_SQRT => fp_ops::hook_fp_sqrt(input),
56        FD_FP_INV => fp_ops::hook_fp_inverse(input),
57        _ => return Ok(()),
58    }?;
59    ctx.execution_result.return_data = output;
60    Ok(())
61}
62
63/// The hook for the `ecrecover` patches.
64///
65/// The input should be of the form [(`curve_id_u8` | `r_is_y_odd_u8` << 7) || `r` || `alpha`]
66/// where:
67/// * `curve_id` is 1 for secp256k1 and 2 for secp256r1
68/// * `r_is_y_odd` is 0 if r is even and 1 if r is odd
69/// * r is the x-coordinate of the point, which should be 32 bytes,
70/// * alpha := r * r * r * (a * r) + b, which should be 32 bytes.
71///
72/// Returns vec![vec![1], `y`, `r_inv`] if the point is decompressable
73/// and vec![vec![0],`nqr_hint`] if not.
74fn hook_ecrecover(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
75    if buf.len() != 65 {
76        return Err(ExitCode::MalformedBuiltinParams);
77    }
78
79    let curve_id = buf[0] & 0b0111_1111;
80    let r_is_y_odd = buf[0] & 0b1000_0000 != 0;
81
82    let r_bytes: [u8; 32] = buf[1..33].try_into().unwrap();
83    let alpha_bytes: [u8; 32] = buf[33..65].try_into().unwrap();
84
85    Ok(match curve_id {
86        1 => ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd),
87        2 => ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd),
88        _ => return Err(ExitCode::MalformedBuiltinParams),
89    })
90}
91
92mod ecrecover {
93    use sp1_curves::{k256, p256};
94
95    /// The non-quadratic residue for the curve for secp256k1 and secp256r1.
96    const NQR: [u8; 32] = {
97        let mut nqr = [0; 32];
98        nqr[31] = 3;
99        nqr
100    };
101
102    pub(super) fn handle_secp256k1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<u8> {
103        use k256::{
104            elliptic_curve::ff::PrimeField, FieldElement as K256FieldElement, Scalar as K256Scalar,
105        };
106
107        let r = K256FieldElement::from_bytes(r.as_ref().into()).unwrap();
108        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
109
110        let alpha = K256FieldElement::from_bytes(alpha.as_ref().into()).unwrap();
111        assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
112
113        // nomralize the y-coordinate always to be consistent.
114        if let Some(mut y_coord) = alpha.sqrt().into_option().map(|y| y.normalize()) {
115            let r = K256Scalar::from_repr(r.to_bytes()).unwrap();
116            let r_inv = r.invert().expect("Non zero r scalar");
117
118            if r_y_is_odd != bool::from(y_coord.is_odd()) {
119                y_coord = y_coord.negate(1);
120                y_coord = y_coord.normalize();
121            }
122
123            let mut result = vec![0x1];
124            result.extend_from_slice(&y_coord.to_bytes());
125            result.extend_from_slice(&r_inv.to_bytes());
126            result
127        } else {
128            let nqr_field = K256FieldElement::from_bytes(NQR.as_ref().into()).unwrap();
129            let qr = alpha * nqr_field;
130            let root = qr
131                .sqrt()
132                .expect("if alpha is not a square, then qr should be a square");
133            let mut result = vec![0x0];
134            result.extend_from_slice(&root.to_bytes());
135            result
136        }
137    }
138
139    pub(super) fn handle_secp256r1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<u8> {
140        use p256::{
141            elliptic_curve::ff::PrimeField, FieldElement as P256FieldElement, Scalar as P256Scalar,
142        };
143
144        let r = P256FieldElement::from_bytes(r.as_ref().into()).unwrap();
145        debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
146
147        let alpha = P256FieldElement::from_bytes(alpha.as_ref().into()).unwrap();
148        debug_assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
149
150        if let Some(mut y_coord) = alpha.sqrt().into_option() {
151            let r = P256Scalar::from_repr(r.to_bytes()).unwrap();
152            let r_inv = r.invert().expect("Non zero r scalar");
153
154            if r_y_is_odd != bool::from(y_coord.is_odd()) {
155                y_coord = -y_coord;
156            }
157
158            let mut result = vec![0x1];
159            result.extend_from_slice(&y_coord.to_bytes());
160            result.extend_from_slice(&r_inv.to_bytes());
161            result
162        } else {
163            let nqr_field = P256FieldElement::from_bytes(NQR.as_ref().into()).unwrap();
164            let qr = alpha * nqr_field;
165            let root = qr
166                .sqrt()
167                .expect("if alpha is not a square, then qr should be a square");
168            let mut result = vec![0x0];
169            result.extend_from_slice(&root.to_bytes());
170            result
171        }
172    }
173}
174
175/// Checks if a compressed Edwards point can be decompressed.
176///
177/// # Arguments
178/// * `env` - The environment in which the hook is invoked.
179/// * `buf` - The buffer containing the compressed Edwards point.
180///    - The compressed Edwards point is 32 bytes.
181///    - The high bit of the last byte is the sign bit.
182///
183/// Returns vec![vec![1]] if the point is decompressable.
184/// Returns vec![vec![0], `v_inv`, `nqr_hint`] if the point is not decompressable.
185///
186/// WARNING: This function merely hints at the validity of the compressed point. These values must
187/// be constrained by the zkVM for correctness.
188pub fn hook_ed_decompress(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
189    const NQR_CURVE_25519: u8 = 2;
190    let modulus = Ed25519BaseField::modulus();
191
192    let mut bytes: [u8; 32] = buf[..32].try_into().unwrap();
193    // Mask the sign bit.
194    bytes[31] &= 0b0111_1111;
195
196    // The AIR asserts canon inputs, so hint here if it cant be satisfied.
197    let y = BigUint::from_bytes_le(&bytes);
198    if y >= modulus {
199        return Ok(vec![0u8]);
200    }
201
202    let v = BigUint::from_bytes_le(&buf[32..]);
203    // This is computed as dy^2 - 1
204    // so it should always be in the field.
205    if v >= modulus {
206        return Err(ExitCode::MalformedBuiltinParams);
207    }
208
209    // For a point to be decompressable, (yy - 1) / (yy * d + 1) must be a quadratic residue.
210    let v_inv = v.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
211    let u = (&y * &y + &modulus - BigUint::one()) % &modulus;
212    let u_div_v = (&u * &v_inv) % &modulus;
213
214    // Note: Our sqrt impl doesnt care about canon representation,
215    // however we have already checked that were less than the modulus.
216    if ed25519_sqrt(&u_div_v).is_some() {
217        return Ok(vec![0x1]);
218    }
219    let qr = (u_div_v * NQR_CURVE_25519) % &modulus;
220    let root = ed25519_sqrt(&qr).unwrap();
221
222    // Pad the results, since this may not be a full 32 bytes.
223    let v_inv_bytes = v_inv.to_bytes_le();
224    let mut v_inv_padded = [0_u8; 32];
225    v_inv_padded[..v_inv_bytes.len()].copy_from_slice(&v_inv.to_bytes_le());
226
227    let root_bytes = root.to_bytes_le();
228    let mut root_padded = [0_u8; 32];
229    root_padded[..root_bytes.len()].copy_from_slice(&root.to_bytes_le());
230
231    let mut result = vec![0x0];
232    result.extend_from_slice(&v_inv_padded);
233    result.extend_from_slice(&root_padded);
234    Ok(result)
235}
236
237/// Given the product of some 256-byte numbers and a modulus, this function does a modular
238/// reduction and hints back the values to the vm in order to constrain it.
239///
240/// # Arguments
241///
242/// * `env` - The environment in which the hook is invoked.
243/// * `buf` - The buffer containing the le bytes of the 512 byte product and the 256 byte modulus.
244///
245/// Returns The le bytes of the product % modulus (512 bytes)
246/// and the quotient floor(product/modulus) (256 bytes).
247///
248/// WANRING: This function is used to perform a modular reduction outside of the zkVM context.
249/// These values must be constrained by the zkVM for correctness.
250pub fn hook_rsa_mul_mod(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
251    if buf.len() != 256 + 256 + 256 {
252        return Err(ExitCode::MalformedBuiltinParams);
253    }
254
255    let prod: &[u8; 512] = buf[..512].try_into().unwrap();
256    let m: &[u8; 256] = buf[512..].try_into().unwrap();
257
258    let prod = BigUint::from_bytes_le(prod);
259    let m = BigUint::from_bytes_le(m);
260
261    let (q, rem) = prod.div_rem(&m);
262
263    let mut rem = rem.to_bytes_le();
264    rem.resize(256, 0);
265
266    let mut q = q.to_bytes_le();
267    q.resize(256, 0);
268
269    let mut result = rem;
270    result.extend_from_slice(&q);
271    Ok(result)
272}
273
274mod bls {
275    use super::{pad_to_be, BigUint};
276    use fluentbase_types::ExitCode;
277    use sp1_curves::{params::FieldParameters, weierstrass::bls12_381::Bls12381BaseField, Zero};
278
279    /// A non-quadratic residue for the `12_381` base field in big endian.
280    pub const NQR_BLS12_381: [u8; 48] = {
281        let mut nqr = [0; 48];
282        nqr[47] = 2;
283        nqr
284    };
285
286    /// The base field modulus for the `12_381` curve, in little endian.
287    pub const BLS12_381_MODULUS: &[u8] = Bls12381BaseField::MODULUS;
288
289    /// Given a field element, in big endian, this function computes the square root.
290    ///
291    /// - If the field element is the additive identity, this function returns `vec![vec![1],
292    ///   vec![0; 48]]`.
293    /// - If the field element is a quadratic residue, this function returns `vec![vec![1],
294    ///   vec![sqrt(fe)]  ]`.
295    /// - If the field element (fe) is not a quadratic residue, this function returns `vec![vec![0],
296    ///   vec![sqrt(``NQR_BLS12_381`` * fe)]]`.
297    pub fn hook_bls12_381_sqrt(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
298        let field_element = BigUint::from_bytes_be(&buf[..48]);
299
300        // This should be checked in the VM as its easier than dispatching a hook call.
301        // But for completeness we include this happy path also.
302        if field_element.is_zero() {
303            let mut result = vec![1];
304            result.resize(48 + 1, 0);
305            return Ok(result);
306        }
307
308        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
309
310        // Since `BLS12_381_MODULUS` == 3 mod 4,. we can use shanks methods.
311        // This means we only need to exponentiate by `(modulus + 1) / 4`.
312        let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64);
313        let sqrt = field_element.modpow(&exp, &modulus);
314
315        // Shanks methods only works if the field element is a quadratic residue.
316        // So we need to check if the square of the sqrt is equal to the field element.
317        let square = (&sqrt * &sqrt) % &modulus;
318        if square != field_element {
319            let nqr = BigUint::from_bytes_be(&NQR_BLS12_381);
320            let qr = (&nqr * &field_element) % &modulus;
321
322            // By now, the product of two non-quadratic residues is a quadratic residue.
323            // So we can use shanks methods again to get its square root.
324            //
325            // We pass this root back to the VM to constrain the "failure" case.
326            let root = qr.modpow(&exp, &modulus);
327
328            debug_assert!(
329                (&root * &root) % &modulus == qr,
330                "NQR sanity check failed, this is a bug."
331            );
332
333            let mut result = vec![0];
334            result.extend(pad_to_be(&root, 48));
335            return Ok(result);
336        }
337
338        let mut result = vec![1];
339        result.extend(pad_to_be(&sqrt, 48));
340        Ok(result)
341    }
342
343    /// Given a field element, in big endian, this function computes the inverse.
344    ///
345    /// This functions will panic if the additive identity is passed in.
346    pub fn hook_bls12_381_inverse(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
347        let field_element = BigUint::from_bytes_be(&buf[..48]);
348
349        // Zero is not invertible, and we dont want to have to return a status from here.
350        if field_element.is_zero() {
351            return Err(ExitCode::MalformedBuiltinParams);
352        }
353
354        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
355
356        // Compute the inverse using Fermat's little theorem, ie, a^(p-2) = a^-1 mod p.
357        let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
358
359        Ok(pad_to_be(&inverse, 48))
360    }
361}
362
363/// Pads a big uint to the given length in big endian.
364fn pad_to_be(val: &BigUint, len: usize) -> Vec<u8> {
365    // First take the byes in little endian
366    let mut bytes = val.to_bytes_le();
367    // Resize so we get the full padding correctly.
368    if len > bytes.len() {
369        bytes.resize(len, 0);
370    }
371    // Convert back to big endian.
372    bytes.reverse();
373
374    bytes
375}
376
377mod fp_ops {
378    use super::{pad_to_be, BigUint, One};
379    use fluentbase_types::ExitCode;
380    use sp1_curves::Zero;
381
382    /// Compute the inverse of a field element.
383    ///
384    /// # Arguments:
385    /// * `buf` - The buffer containing the data needed to compute the inverse.
386    ///     - [ len || Element || Modulus ]
387    ///     - len is the u32 length of the element and modulus in big endian.
388    ///     - Element is the field element to compute the inverse of, interpreted as a big endian
389    ///       integer of `len` bytes.
390    ///
391    /// # Returns:
392    /// A single 32 byte vector containing the inverse.
393    ///
394    /// # Panics:
395    /// - If the buffer length is not valid.
396    /// - If the element is zero.
397    pub fn hook_fp_inverse(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
398        let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
399
400        if buf.len() != 4 + 2 * len {
401            return Err(ExitCode::MalformedBuiltinParams);
402        }
403
404        let buf = &buf[4..];
405        let element = BigUint::from_bytes_be(&buf[..len]);
406        let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
407
408        if element.is_zero() {
409            return Err(ExitCode::MalformedBuiltinParams);
410        }
411
412        let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
413
414        Ok(pad_to_be(&inverse, len))
415    }
416
417    /// Compute the square root of a field element.
418    ///
419    /// # Arguments:
420    /// * `buf` - The buffer containing the data needed to compute the square root.
421    ///     - [ len || Element || Modulus || NQR ]
422    ///     - len is the length of the element, modulus, and nqr in big endian.
423    ///     - Element is the field element to compute the square root of, interpreted as a big
424    ///       endian integer of `len` bytes.
425    ///     - Modulus is the modulus of the field, interpreted as a big endian integer of `len`
426    ///       bytes.
427    ///     - NQR is the non-quadratic residue of the field, interpreted as a big endian integer of
428    ///       `len` bytes.
429    ///
430    /// # Assumptions
431    /// - NQR is a non-quadratic residue of the field.
432    ///
433    /// # Returns:
434    /// [ `status_u8` || `root_bytes` ]
435    ///
436    /// If the status is 0, this is the root of NQR * element.
437    /// If the status is 1, this is the root of element.
438    ///
439    /// # Panics:
440    /// - If the buffer length is not valid.
441    /// - If the element is not less than the modulus.
442    /// - If the nqr is not less than the modulus.
443    /// - If the element is zero.
444    pub fn hook_fp_sqrt(buf: &[u8]) -> Result<Vec<u8>, ExitCode> {
445        let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
446
447        if buf.len() != 4 + 3 * len {
448            return Err(ExitCode::MalformedBuiltinParams);
449        }
450
451        let buf = &buf[4..];
452        let element = BigUint::from_bytes_be(&buf[..len]);
453        let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
454        let nqr = BigUint::from_bytes_be(&buf[2 * len..3 * len]);
455
456        if element > modulus || nqr > modulus {
457            return Err(ExitCode::MalformedBuiltinParams);
458        }
459
460        // The sqrt of zero is zero.
461        if element.is_zero() {
462            let mut result = vec![1];
463            result.resize(len + 1, 0);
464            return Ok(result);
465        }
466
467        // Compute the square root of the element using the general Tonelli-Shanks algorithm.
468        // The implementation can be used for any field as it is field-agnostic.
469        if let Some(root) = sqrt_fp(&element, &modulus, &nqr) {
470            let mut result = vec![1];
471            result.extend(pad_to_be(&root, len));
472            Ok(result)
473        } else {
474            let qr = (&nqr * &element) % &modulus;
475            let root = sqrt_fp(&qr, &modulus, &nqr).unwrap();
476            let mut result = vec![0];
477            result.extend(pad_to_be(&root, len));
478            Ok(result)
479        }
480    }
481
482    /// Compute the square root of a field element for some modulus.
483    ///
484    /// Requires a known non-quadratic residue of the field.
485    fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
486        // If the prime field is of the form p = 3 mod 4, and `x` is a quadratic residue modulo `p`,
487        // then one square root of `x` is given by `x^(p+1 / 4) mod p`.
488        if modulus % BigUint::from(4u64) == BigUint::from(3u64) {
489            let maybe_root = element.modpow(
490                &((modulus + BigUint::from(1u64)) / BigUint::from(4u64)),
491                modulus,
492            );
493
494            return Some(maybe_root).filter(|root| root * root % modulus == *element);
495        }
496
497        tonelli_shanks(element, modulus, nqr)
498    }
499
500    /// Compute the square root of a field element using the Tonelli-Shanks algorithm.
501    ///
502    /// # Arguments:
503    /// * `element` - The field element to compute the square root of.
504    /// * `modulus` - The modulus of the field.
505    /// * `nqr` - The non-quadratic residue of the field.
506    ///
507    /// # Assumptions:
508    /// - The element is a quadratic residue modulo the modulus.
509    ///
510    /// Ref: <https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm>
511    #[allow(clippy::many_single_char_names)]
512    fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
513        // First, compute the Legendre symbol of the element.
514        // If the symbol is not 1, then the element is not a quadratic residue.
515        if legendre_symbol(element, modulus) != BigUint::one() {
516            return None;
517        }
518
519        // Find the values of Q and S such that modulus - 1 = Q * 2^S.
520        let mut s = BigUint::zero();
521        let mut q = modulus - BigUint::one();
522        while &q % &BigUint::from(2u64) == BigUint::zero() {
523            s += BigUint::from(1u64);
524            q /= BigUint::from(2u64);
525        }
526
527        let z = nqr;
528        let mut c = z.modpow(&q, modulus);
529        let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus);
530        let mut t = element.modpow(&q, modulus);
531        let mut m = s;
532
533        while t != BigUint::one() {
534            let mut i = BigUint::zero();
535            let mut tt = t.clone();
536            while tt != BigUint::one() {
537                tt = &tt * &tt % modulus;
538                i += BigUint::from(1u64);
539
540                if i == m {
541                    return None;
542                }
543            }
544
545            let b_pow =
546                BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap());
547            let b = c.modpow(&b_pow, modulus);
548
549            r = &r * &b % modulus;
550            c = &b * &b % modulus;
551            t = &t * &c % modulus;
552            m = i;
553        }
554
555        Some(r)
556    }
557
558    /// Compute the Legendre symbol of a field element.
559    ///
560    /// This indicates if the element is a quadratic in the prime field.
561    ///
562    /// Ref: <https://en.wikipedia.org/wiki/Legendre_symbol>
563    fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint {
564        assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called.");
565
566        element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus)
567    }
568
569    #[cfg(test)]
570    mod test {
571        use super::*;
572        use std::str::FromStr;
573
574        #[test]
575        fn test_legendre_symbol() {
576            // The modulus of the secp256k1 base field.
577            let modulus = BigUint::from_str(
578                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
579            )
580            .unwrap();
581            let neg_1 = &modulus - BigUint::one();
582
583            let fixtures = [
584                (BigUint::from(4u64), BigUint::from(1u64)),
585                (BigUint::from(2u64), BigUint::from(1u64)),
586                (BigUint::from(3u64), neg_1.clone()),
587            ];
588
589            for (element, expected) in fixtures {
590                let result = legendre_symbol(&element, &modulus);
591                assert_eq!(result, expected);
592            }
593        }
594
595        #[test]
596        fn test_tonelli_shanks() {
597            // The modulus of the secp256k1 base field.
598            let p = BigUint::from_str(
599                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
600            )
601            .unwrap();
602
603            let nqr = BigUint::from_str("3").unwrap();
604
605            let large_element = &p - BigUint::from(u16::MAX);
606            let square = &large_element * &large_element % &p;
607
608            let fixtures = [
609                (BigUint::from(2u64), true),
610                (BigUint::from(3u64), false),
611                (BigUint::from(4u64), true),
612                (square, true),
613            ];
614
615            for (element, expected) in fixtures {
616                let result = tonelli_shanks(&element, &p, &nqr);
617                if expected {
618                    assert!(result.is_some());
619
620                    let result = result.unwrap();
621                    assert!((&result * &result) % &p == element);
622                } else {
623                    assert!(result.is_none());
624                }
625            }
626        }
627    }
628}