use core::fmt::Debug;
use std::sync::{Arc, RwLock, RwLockWriteGuard};
use hashbrown::HashMap;
use sp1_curves::{
edwards::ed25519::ed25519_sqrt, params::FieldParameters, BigUint, Integer, One, Zero,
};
/// A type-erased, shareable hook.
///
/// The hook lives behind `Arc<RwLock<..>>`: cloning the `Arc` shares the same hook instance,
/// and the write lock gives the `&mut self` access that `Hook::invoke_hook` requires.
pub type BoxedHook<'a> = Arc<RwLock<dyn Hook + Send + Sync + 'a>>;
pub use sp1_primitives::consts::fd::*;
/// A handler that consumes an input byte buffer and produces a list of output byte buffers.
///
/// Any `FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>` implements this trait through a blanket impl.
pub trait Hook {
    /// Runs the hook on `buf` with context `env`, returning its outputs.
    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>>;
}
/// Lets any matching closure or function item be used directly as a [`Hook`].
impl<F> Hook for F
where
    F: FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>>,
{
    fn invoke_hook(&mut self, env: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        // `self` is `&mut F`, which is itself callable because `F: FnMut`.
        (self)(env, buf)
    }
}
/// Wraps a hook closure in the shared, lock-guarded [`BoxedHook`] form used by the registry.
pub fn hookify<'a>(
    f: impl FnMut(HookEnv, &[u8]) -> Vec<Vec<u8>> + Send + Sync + 'a,
) -> BoxedHook<'a> {
    let shared: BoxedHook<'a> = Arc::new(RwLock::new(f));
    shared
}
/// A table of hooks keyed by file descriptor.
///
/// `Clone` copies the `Arc` handles, so a cloned registry shares the same hook instances.
#[derive(Clone)]
pub struct HookRegistry<'a> {
    /// File descriptor -> hook registered for that descriptor.
    pub(crate) table: HashMap<u32, BoxedHook<'a>>,
}
impl<'a> HookRegistry<'a> {
    /// Creates a registry pre-populated with the built-in hooks (identical to `default()`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a registry with no hooks registered.
    #[must_use]
    pub fn empty() -> Self {
        let table = HashMap::default();
        Self { table }
    }

    /// Looks up the hook registered for `fd` and takes its write lock.
    ///
    /// Returns `None` when nothing is registered for `fd`.
    ///
    /// # Panics
    ///
    /// Panics if the hook's lock has been poisoned.
    #[must_use]
    pub fn get(&self, fd: u32) -> Option<RwLockWriteGuard<'_, dyn Hook + Send + Sync + 'a>> {
        let hook = self.table.get(&fd)?;
        Some(hook.write().unwrap())
    }
}
impl Default for HookRegistry<'_> {
fn default() -> Self {
let table = HashMap::from([
(FD_ECRECOVER_HOOK, hookify(hook_ecrecover)),
(FD_EDDECOMPRESS, hookify(hook_ed_decompress)),
(FD_RSA_MUL_MOD, hookify(hook_rsa_mul_mod)),
(FD_BLS12_381_SQRT, hookify(bls::hook_bls12_381_sqrt)),
(FD_BLS12_381_INVERSE, hookify(bls::hook_bls12_381_inverse)),
(FD_FP_SQRT, hookify(fp_ops::hook_fp_sqrt)),
(FD_FP_INV, hookify(fp_ops::hook_fp_inverse)),
]);
Self { table }
}
}
impl Debug for HookRegistry<'_> {
    /// Prints the number of registered hooks and their file descriptors in ascending order
    /// (sorted so the output does not depend on hash-map iteration order).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut fds: Vec<&u32> = self.table.keys().collect();
        fds.sort_unstable();
        let summary = format!("{{{} hooks registered at {:?}}}", self.table.len(), fds);
        // `Arguments` debug-prints without quotes, unlike a `String` field would.
        f.debug_struct("HookRegistry").field("table", &format_args!("{}", summary)).finish()
    }
}
/// Context value passed to every hook invocation. It currently carries no data.
pub struct HookEnv {}
/// Square-root hint hook for ECDSA public-key recovery.
///
/// `buf` must be exactly 65 bytes:
/// - byte 0: low 7 bits are the curve id (1 = secp256k1, 2 = secp256r1); the top bit requests
///   an odd y-coordinate;
/// - bytes 1..33: `r`;
/// - bytes 33..65: `alpha`, the value whose square root the curve handler looks for.
///
/// # Panics
///
/// Panics if `buf` is not 65 bytes long, or if the curve id is not 1 or 2.
#[must_use]
pub fn hook_ecrecover(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
    assert!(buf.len() == 64 + 1, "ecrecover should have length 65");

    let (header, payload) = (buf[0], &buf[1..]);
    let curve_id = header & 0b0111_1111;
    let r_is_y_odd = (header >> 7) == 1;

    let (r_part, alpha_part) = payload.split_at(32);
    let r_bytes: [u8; 32] = r_part.try_into().unwrap();
    let alpha_bytes: [u8; 32] = alpha_part.try_into().unwrap();

    match curve_id {
        1 => ecrecover::handle_secp256k1(r_bytes, alpha_bytes, r_is_y_odd),
        2 => ecrecover::handle_secp256r1(r_bytes, alpha_bytes, r_is_y_odd),
        _ => unimplemented!("Unsupported curve id: {}", curve_id),
    }
}
/// Per-curve handlers for `hook_ecrecover`.
///
/// Each handler returns one of two results:
/// - `[[1], y, r_inv]` when `alpha` is a square in the base field. `y` is the square root
///   whose parity matches `r_y_is_odd`, and `r_inv` is `r` inverted in the scalar field.
/// - `[[0], root]` otherwise. `root` is a square root of `alpha * 3`, which witnesses that
///   `alpha` is not a square.
mod ecrecover {
use sp1_curves::{k256, p256};
/// 32 bytes with only the last byte set to 3. This is the integer 3, presumably under the
/// big-endian (SEC1) encoding that `from_bytes` uses. Each handler's `qr.sqrt().expect(..)`
/// relies on 3 being a quadratic non-residue in the respective base field.
const NQR: [u8; 32] = {
let mut nqr = [0; 32];
nqr[31] = 3;
nqr
};
/// secp256k1 handler; see the module docs for the output format.
///
/// # Panics
///
/// - If `r` or `alpha` is not a canonical base-field encoding (`from_bytes(..).unwrap()`).
/// - If `alpha` is zero. This check is active in all builds.
/// - If `r` is not a valid scalar-field value (`from_repr(..).unwrap()`).
/// - If `r` is zero. `debug_assert!` catches this in debug builds; otherwise
///   `invert().expect(..)` does.
pub(super) fn handle_secp256k1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
use k256::{
elliptic_curve::ff::PrimeField, FieldElement as K256FieldElement, Scalar as K256Scalar,
};
let r = K256FieldElement::from_bytes(&r.into()).unwrap();
debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
let alpha = K256FieldElement::from_bytes(&alpha.into()).unwrap();
assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
// The root is normalized before its parity is inspected or it is serialized.
if let Some(mut y_coord) = alpha.sqrt().into_option().map(|y| y.normalize()) {
// Reinterpret the base-field bytes of `r` as a scalar and invert it.
let r = K256Scalar::from_repr(r.to_bytes()).unwrap();
let r_inv = r.invert().expect("Non zero r scalar");
// Flip to the other root when the parity does not match the request. `negate(1)` is
// given magnitude 1 because `y_coord` was just normalized; re-normalize afterwards.
if r_y_is_odd != bool::from(y_coord.is_odd()) {
y_coord = y_coord.negate(1);
y_coord = y_coord.normalize();
}
vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
} else {
// `alpha` is a non-residue, so `alpha * NQR` is a residue. Its root is the witness.
let nqr_field = K256FieldElement::from_bytes(&NQR.into()).unwrap();
let qr = alpha * nqr_field;
let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");
vec![vec![0], root.to_bytes().to_vec()]
}
}
/// secp256r1 (P-256) handler; see the module docs for the output format.
///
/// # Panics
///
/// - If `r` or `alpha` is not a canonical base-field encoding.
/// - If `r` is not a valid scalar-field value.
/// - If `r` is zero: `debug_assert!` in debug builds, `invert().expect(..)` otherwise.
///
/// NOTE(review): unlike `handle_secp256k1`, a zero `alpha` is rejected only in debug builds.
/// In release it presumably takes the square branch with `y = 0`. Confirm whether this
/// should be an `assert!` as on secp256k1.
pub(super) fn handle_secp256r1(r: [u8; 32], alpha: [u8; 32], r_y_is_odd: bool) -> Vec<Vec<u8>> {
use p256::{
elliptic_curve::ff::PrimeField, FieldElement as P256FieldElement, Scalar as P256Scalar,
};
let r = P256FieldElement::from_bytes(&r.into()).unwrap();
debug_assert!(!bool::from(r.is_zero()), "r should not be zero");
let alpha = P256FieldElement::from_bytes(&alpha.into()).unwrap();
debug_assert!(!bool::from(alpha.is_zero()), "alpha should not be zero");
if let Some(mut y_coord) = alpha.sqrt().into_option() {
// Reinterpret the base-field bytes of `r` as a scalar and invert it.
let r = P256Scalar::from_repr(r.to_bytes()).unwrap();
let r_inv = r.invert().expect("Non zero r scalar");
// Flip to the other root when the parity does not match the request.
if r_y_is_odd != bool::from(y_coord.is_odd()) {
y_coord = -y_coord;
}
vec![vec![1], y_coord.to_bytes().to_vec(), r_inv.to_bytes().to_vec()]
} else {
// `alpha` is a non-residue, so `alpha * NQR` is a residue. Its root is the witness.
let nqr_field = P256FieldElement::from_bytes(&NQR.into()).unwrap();
let qr = alpha * nqr_field;
let root = qr.sqrt().expect("if alpha is not a square, then qr should be a square");
vec![vec![0], root.to_bytes().to_vec()]
}
}
}
/// Prime-field hint hooks (modular inverse and square root) for a caller-supplied modulus.
///
/// Both hooks take the following buffer layout:
/// - a 4-byte big-endian length `len`;
/// - the operands, each `len` bytes, big-endian.
///
/// Every field-element output is `len` bytes, big-endian. The modulus is assumed to be prime;
/// this is not checked.
pub mod fp_ops {
    use super::{pad_to_be, BigUint, HookEnv, One, Zero};

    /// Splits off the 4-byte big-endian length prefix and returns `(len, rest)`.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than 4 bytes.
    fn split_len_prefix(buf: &[u8]) -> (usize, &[u8]) {
        assert!(buf.len() >= 4, "FpOp: Invalid buffer length");
        let (prefix, rest) = buf.split_at(4);
        let len = u32::from_be_bytes(prefix.try_into().unwrap()) as usize;
        (len, rest)
    }

    /// Computes `element^-1 mod modulus` using Fermat's little theorem.
    ///
    /// Input layout: `len | element | modulus`. Output: `[inverse]`.
    ///
    /// # Panics
    ///
    /// - If the buffer length does not match `4 + 2 * len`.
    /// - If `modulus <= 1`.
    /// - If `element` is congruent to zero modulo `modulus`.
    #[must_use]
    pub fn hook_fp_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        let (len, buf) = split_len_prefix(buf);
        assert!(buf.len() == 2 * len, "FpOp: Invalid buffer length");
        let (element_bytes, modulus_bytes) = buf.split_at(len);
        let modulus = BigUint::from_bytes_be(modulus_bytes);
        assert!(modulus > BigUint::one(), "FpOp: modulus must be greater than one");
        // Reduce first. Otherwise a non-canonical zero (e.g. `element == modulus`) would pass
        // the zero check and silently yield 0 as its "inverse".
        let element = BigUint::from_bytes_be(element_bytes) % &modulus;
        assert!(!element.is_zero(), "FpOp: Inverse called with zero");
        let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
        vec![pad_to_be(&inverse, len)]
    }

    /// Square-root hint for `element` modulo `modulus`.
    ///
    /// Input layout: `len | element | modulus | nqr`. `nqr` must be a non-zero quadratic
    /// non-residue mod `modulus`.
    ///
    /// Output:
    /// - `[[1], root]` with `root^2 == element` when `element` is a square (zero included);
    /// - `[[0], root]` with `root^2 == nqr * element` otherwise, which witnesses that
    ///   `element` is not a square.
    ///
    /// # Panics
    ///
    /// - If the buffer length does not match `4 + 3 * len`.
    /// - If `element` or `nqr` is not canonical, or if `nqr` is zero.
    /// - If `nqr` turns out not to be a non-residue.
    #[must_use]
    pub fn hook_fp_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        let (len, buf) = split_len_prefix(buf);
        assert!(buf.len() == 3 * len, "FpOp: Invalid buffer length");
        let element = BigUint::from_bytes_be(&buf[..len]);
        let modulus = BigUint::from_bytes_be(&buf[len..2 * len]);
        let nqr = BigUint::from_bytes_be(&buf[2 * len..]);
        assert!(
            element < modulus,
            "Element is not less than modulus, the hook only accepts canonical representations"
        );
        // A zero `nqr` would make `nqr * element == 0`, whose root (0) is a bogus witness.
        assert!(
            !nqr.is_zero() && nqr < modulus,
            "NQR is zero or non-canonical, the hook only accepts canonical representations"
        );
        if element.is_zero() {
            return vec![vec![1], vec![0; len]];
        }
        if let Some(root) = sqrt_fp(&element, &modulus, &nqr) {
            return vec![vec![1], pad_to_be(&root, len)];
        }
        let qr = (&nqr * &element) % &modulus;
        let root = sqrt_fp(&qr, &modulus, &nqr)
            .expect("FpOp: nqr * element is not a square, so nqr is not a quadratic non-residue");
        vec![vec![0], pad_to_be(&root, len)]
    }

    /// Returns a square root of `element` mod `modulus`, or `None` if none exists.
    ///
    /// When `modulus ≡ 3 (mod 4)` it uses the `element^((p+1)/4)` shortcut and verifies the
    /// result. Otherwise it falls back to Tonelli–Shanks.
    fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
        if modulus % BigUint::from(4u64) == BigUint::from(3u64) {
            let maybe_root =
                element.modpow(&((modulus + BigUint::from(1u64)) / BigUint::from(4u64)), modulus);
            return Some(maybe_root).filter(|root| root * root % modulus == *element);
        }
        tonelli_shanks(element, modulus, nqr)
    }

    /// Tonelli–Shanks square root of a non-zero `element` modulo an odd prime `modulus`.
    ///
    /// `nqr` must be a quadratic non-residue. Returns `None` when `element` is not a square.
    #[allow(clippy::many_single_char_names)]
    fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option<BigUint> {
        if legendre_symbol(element, modulus) != BigUint::one() {
            return None;
        }
        // Write modulus - 1 = q * 2^s with q odd.
        let mut s = BigUint::zero();
        let mut q = modulus - BigUint::one();
        while &q % &BigUint::from(2u64) == BigUint::zero() {
            s += BigUint::from(1u64);
            q /= BigUint::from(2u64);
        }
        let mut c = nqr.modpow(&q, modulus);
        let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus);
        let mut t = element.modpow(&q, modulus);
        let mut m = s;
        while t != BigUint::one() {
            // Find the least i with t^(2^i) == 1.
            let mut i = BigUint::zero();
            let mut tt = t.clone();
            while tt != BigUint::one() {
                tt = &tt * &tt % modulus;
                i += BigUint::from(1u64);
                if i == m {
                    return None;
                }
            }
            let b_pow =
                BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap());
            let b = c.modpow(&b_pow, modulus);
            r = &r * &b % modulus;
            c = &b * &b % modulus;
            t = &t * &c % modulus;
            m = i;
        }
        Some(r)
    }

    /// Euler's criterion: `element^((p-1)/2) mod p`. For prime `p` the result is 1 for a
    /// square and `p - 1` for a non-square.
    ///
    /// # Panics
    ///
    /// Panics if `element` is zero.
    fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint {
        assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called.");
        element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus)
    }

    #[cfg(test)]
    mod test {
        use super::*;
        use std::str::FromStr;

        /// Builds an FpOp buffer with `len = 1` followed by the given one-byte operands.
        fn one_byte_buf(operands: &[u8]) -> Vec<u8> {
            let mut buf = 1u32.to_be_bytes().to_vec();
            buf.extend_from_slice(operands);
            buf
        }

        #[test]
        fn test_legendre_symbol() {
            let modulus = BigUint::from_str(
                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
            )
            .unwrap();
            let neg_1 = &modulus - BigUint::one();
            let fixtures = [
                (BigUint::from(4u64), BigUint::from(1u64)),
                (BigUint::from(2u64), BigUint::from(1u64)),
                (BigUint::from(3u64), neg_1.clone()),
            ];
            for (element, expected) in fixtures {
                let result = legendre_symbol(&element, &modulus);
                assert_eq!(result, expected);
            }
        }

        #[test]
        fn test_tonelli_shanks() {
            let p = BigUint::from_str(
                "115792089237316195423570985008687907853269984665640564039457584007908834671663",
            )
            .unwrap();
            let nqr = BigUint::from_str("3").unwrap();
            let large_element = &p - BigUint::from(u16::MAX);
            let square = &large_element * &large_element % &p;
            let fixtures = [
                (BigUint::from(2u64), true),
                (BigUint::from(3u64), false),
                (BigUint::from(4u64), true),
                (square, true),
            ];
            for (element, expected) in fixtures {
                let result = tonelli_shanks(&element, &p, &nqr);
                if expected {
                    assert!(result.is_some());
                    let result = result.unwrap();
                    assert!((&result * &result) % &p == element);
                } else {
                    assert!(result.is_none());
                }
            }
        }

        #[test]
        fn test_tonelli_shanks_p_1_mod_4() {
            // 13 ≡ 1 (mod 4) exercises the full Tonelli–Shanks loop; 2 is a non-residue mod 13.
            let p = BigUint::from(13u64);
            let nqr = BigUint::from(2u64);
            for x in 1u64..13 {
                let x = BigUint::from(x);
                let is_square = legendre_symbol(&x, &p) == BigUint::one();
                match tonelli_shanks(&x, &p, &nqr) {
                    Some(root) => {
                        assert!(is_square);
                        assert_eq!(&root * &root % &p, x);
                    }
                    None => assert!(!is_square),
                }
            }
        }

        #[test]
        fn test_fp_inverse_small() {
            // 3 * 5 = 15 ≡ 1 (mod 7).
            assert_eq!(hook_fp_inverse(HookEnv {}, &one_byte_buf(&[3, 7])), vec![vec![5]]);
        }

        #[test]
        #[should_panic(expected = "FpOp: Inverse called with zero")]
        fn test_fp_inverse_rejects_non_canonical_zero() {
            let _ = hook_fp_inverse(HookEnv {}, &one_byte_buf(&[7, 7]));
        }

        #[test]
        fn test_fp_sqrt_small() {
            // 4^2 = 16 ≡ 2 (mod 7): 2 is a square.
            assert_eq!(hook_fp_sqrt(HookEnv {}, &one_byte_buf(&[2, 7, 3])), vec![vec![1], vec![4]]);
            // 3 is a non-residue mod 7; 3 * 3 ≡ 2, whose root is 4.
            assert_eq!(hook_fp_sqrt(HookEnv {}, &one_byte_buf(&[3, 7, 3])), vec![vec![0], vec![4]]);
        }

        #[test]
        #[should_panic(expected = "NQR is zero or non-canonical")]
        fn test_fp_sqrt_rejects_zero_nqr() {
            let _ = hook_fp_sqrt(HookEnv {}, &one_byte_buf(&[3, 7, 0]));
        }
    }
}
#[must_use]
pub fn hook_ed_decompress(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
const NQR_CURVE_25519: u8 = 2;
let modulus = sp1_curves::edwards::ed25519::Ed25519BaseField::modulus();
let mut bytes: [u8; 32] = buf[..32].try_into().unwrap();
bytes[31] &= 0b0111_1111;
let y = BigUint::from_bytes_le(&bytes);
if y >= modulus {
return vec![vec![0]];
}
let v = BigUint::from_bytes_le(&buf[32..]);
assert!(v < modulus, "V is not a valid field element");
let v_inv = v.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
let u = (&y * &y + &modulus - BigUint::one()) % &modulus;
let u_div_v = (&u * &v_inv) % &modulus;
if ed25519_sqrt(&u_div_v).is_some() {
vec![vec![1]]
} else {
let qr = (u_div_v * NQR_CURVE_25519) % &modulus;
let root = ed25519_sqrt(&qr).unwrap();
let v_inv_bytes = v_inv.to_bytes_le();
let mut v_inv_padded = [0_u8; 32];
v_inv_padded[..v_inv_bytes.len()].copy_from_slice(&v_inv.to_bytes_le());
let root_bytes = root.to_bytes_le();
let mut root_padded = [0_u8; 32];
root_padded[..root_bytes.len()].copy_from_slice(&root.to_bytes_le());
vec![vec![0], v_inv_padded.to_vec(), root_padded.to_vec()]
}
}
/// BLS12-381 base-field hint hooks (square root and inverse).
///
/// Both hooks read a 48-byte big-endian field element from the start of the buffer and reduce
/// it modulo p before use.
pub mod bls {
    use super::{pad_to_be, BigUint, HookEnv};
    use sp1_curves::{params::FieldParameters, weierstrass::bls12_381::Bls12381BaseField, Zero};

    /// The integer 2 as 48 big-endian bytes. It is a quadratic non-residue because the
    /// BLS12-381 base-field prime is ≡ 3 (mod 8).
    pub const NQR_BLS12_381: [u8; 48] = {
        let mut nqr = [0; 48];
        nqr[47] = 2;
        nqr
    };

    /// BLS12-381 base-field modulus, stored as little-endian bytes.
    pub const BLS12_381_MODULUS: &[u8] = Bls12381BaseField::MODULUS;

    /// Square-root hint for the field element in `buf[..48]`.
    ///
    /// Returns `[[1], sqrt(x)]` if `x` is a square (zero included), otherwise
    /// `[[0], sqrt(2·x)]`. Each root is 48 big-endian bytes.
    ///
    /// # Panics
    ///
    /// Panics if `buf` is shorter than 48 bytes.
    #[must_use]
    pub fn hook_bls12_381_sqrt(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
        // Reduce first. A non-canonical input (>= p) could never equal a reduced square below,
        // so a square would be misrouted to the NQR branch and trip its sanity assert.
        let field_element = BigUint::from_bytes_be(&buf[..48]) % &modulus;
        if field_element.is_zero() {
            return vec![vec![1], vec![0; 48]];
        }
        // p ≡ 3 (mod 4): x^((p+1)/4) is a square root of x whenever one exists.
        let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64);
        let sqrt = field_element.modpow(&exp, &modulus);
        if (&sqrt * &sqrt) % &modulus == field_element {
            return vec![vec![1], pad_to_be(&sqrt, 48)];
        }
        let nqr = BigUint::from_bytes_be(&NQR_BLS12_381);
        let qr = (&nqr * &field_element) % &modulus;
        let root = qr.modpow(&exp, &modulus);
        assert!((&root * &root) % &modulus == qr, "NQR sanity check failed, this is a bug.");
        vec![vec![0], pad_to_be(&root, 48)]
    }

    /// Inverse hint for the field element in `buf[..48]`, computed as `x^(p-2) mod p`.
    ///
    /// # Panics
    ///
    /// - If `buf` is shorter than 48 bytes.
    /// - If the element is congruent to zero mod p.
    #[must_use]
    pub fn hook_bls12_381_inverse(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
        let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
        // Reduce first so a non-canonical zero (e.g. exactly p) is rejected, not "inverted" to 0.
        let field_element = BigUint::from_bytes_be(&buf[..48]) % &modulus;
        assert!(!field_element.is_zero(), "Field element is the additive identity");
        let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus);
        vec![pad_to_be(&inverse, 48)]
    }

    #[cfg(test)]
    mod test {
        use super::*;

        #[test]
        fn sqrt_accepts_non_canonical_square() {
            let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
            let input = &modulus + BigUint::from(4u64);
            let out = hook_bls12_381_sqrt(HookEnv {}, &pad_to_be(&input, 48));
            assert_eq!(out[0], vec![1]);
            let root = BigUint::from_bytes_be(&out[1]);
            assert_eq!((&root * &root) % &modulus, BigUint::from(4u64));
        }

        #[test]
        #[should_panic(expected = "Field element is the additive identity")]
        fn inverse_rejects_modulus() {
            let modulus = BigUint::from_bytes_le(BLS12_381_MODULUS);
            let _ = hook_bls12_381_inverse(HookEnv {}, &pad_to_be(&modulus, 48));
        }
    }
}
#[must_use]
pub fn hook_rsa_mul_mod(_: HookEnv, buf: &[u8]) -> Vec<Vec<u8>> {
assert!(
buf.len() == 256 * 3 || buf.len() == 384 * 3 || buf.len() == 512 * 3,
"rsa_mul_mod input should have length key_size * 3, this is a bug."
);
let len = buf.len() / 3;
let prod = BigUint::from_bytes_le(&buf[..2 * len]);
let m = BigUint::from_bytes_le(&buf[2 * len..]);
let (q, rem) = prod.div_rem(&m);
let mut rem = rem.to_bytes_le();
rem.resize(len, 0);
let mut q = q.to_bytes_le();
q.resize(len, 0);
vec![rem, q]
}
/// Serializes `val` as exactly `len` big-endian bytes.
///
/// Shorter encodings are left-padded with zeros. If `val` needs more than `len` bytes, its
/// most-significant bytes are dropped.
fn pad_to_be(val: &BigUint, len: usize) -> Vec<u8> {
    let mut little_endian = val.to_bytes_le();
    little_endian.resize(len, 0);
    little_endian.into_iter().rev().collect()
}
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]
    use super::*;

    /// The default registry must ship with at least one hook, and must be printable.
    #[test]
    pub fn registry_new_is_inhabited() {
        let registry = HookRegistry::new();
        assert_ne!(registry.table.len(), 0);
        println!("{:?}", registry);
    }

    /// `HookRegistry::empty` must start with no hooks at all.
    #[test]
    pub fn registry_empty_is_empty() {
        let registry = HookRegistry::empty();
        assert!(registry.table.is_empty());
    }
}