use ark_bn254::{Fq, Fq2, Fr, G1Affine, G1Projective, G2Affine, G2Projective};
use ark_ff::BigInteger;
use ark_ff::PrimeField;
use fastcrypto::error::FastCryptoError;
use num_bigint::BigUint;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
/// An element of the BN254 base field ([`Fq`]) held as its 32-byte big-endian encoding.
///
/// It is (de)serialized as a base-10 string.
#[derive(Clone, Debug, PartialEq, Eq, JsonSchema)]
pub struct Bn254FqElement(#[schemars(with = "String")] [u8; 32]);

/// A G1 point in the layout read by `g1_affine_from_str_projective`: the projective
/// coordinates `[x, y, z]`.
pub type CircomG1 = Vec<Bn254FqElement>;

/// A G2 point in the layout read by `g2_affine_from_str_projective`: three [`Fq2`]
/// coordinates `[x, y, z]`, each written as a pair `[c0, c1]`.
pub type CircomG2 = Vec<Vec<Bn254FqElement>>;
impl std::str::FromStr for Bn254FqElement {
type Err = FastCryptoError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let big_int = Fq::from_str(s).map_err(|_| FastCryptoError::InvalidInput)?;
let be_bytes = big_int.into_bigint().to_bytes_be();
be_bytes
.try_into()
.map_err(|_| FastCryptoError::InvalidInput)
.map(Bn254FqElement)
}
}
impl std::fmt::Display for Bn254FqElement {
    /// Writes the stored big-endian bytes as an unsigned base-10 integer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&BigUint::from_bytes_be(&self.0).to_str_radix(10))
    }
}
impl Serialize for Bn254FqElement {
    /// Serializes the element as the base-10 string produced by its `Display` impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let decimal = self.to_string();
        serializer.serialize_str(&decimal)
    }
}
impl<'de> Deserialize<'de> for Bn254FqElement {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
}
}
/// An element of the BN254 scalar field ([`Fr`]) held as its 32-byte big-endian encoding.
///
/// It is (de)serialized as a base-10 string.
#[derive(Clone, Debug, PartialEq, Eq, JsonSchema)]
pub struct Bn254FrElement(#[schemars(with = "String")] [u8; 32]);
impl Bn254FrElement {
    /// Returns the big-endian bytes with every leading zero byte removed.
    ///
    /// The result is never empty. If all 32 bytes are zero, the final byte (a single `0`) is
    /// returned.
    pub fn unpadded(&self) -> &[u8] {
        let start = self
            .0
            .iter()
            .position(|&byte| byte != 0)
            .unwrap_or(self.0.len() - 1);
        &self.0[start..]
    }

    /// Returns the full fixed-width 32-byte big-endian encoding, leading zeros included.
    pub fn padded(&self) -> &[u8] {
        self.0.as_slice()
    }
}
impl std::str::FromStr for Bn254FrElement {
type Err = FastCryptoError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let big_int = Fr::from_str(s).map_err(|_| FastCryptoError::InvalidInput)?;
let be_bytes = big_int.into_bigint().to_bytes_be();
be_bytes
.try_into()
.map_err(|_| FastCryptoError::InvalidInput)
.map(Bn254FrElement)
}
}
impl std::fmt::Display for Bn254FrElement {
    /// Writes the stored big-endian bytes as an unsigned base-10 integer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&BigUint::from_bytes_be(&self.0).to_str_radix(10))
    }
}
impl Serialize for Bn254FrElement {
    /// Serializes the element as the base-10 string produced by its `Display` impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let decimal = self.to_string();
        serializer.serialize_str(&decimal)
    }
}
impl<'de> Deserialize<'de> for Bn254FrElement {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s = std::borrow::Cow::<'de, str>::deserialize(deserializer)?;
std::str::FromStr::from_str(&s).map_err(serde::de::Error::custom)
}
}
impl From<&Bn254FqElement> for Fq {
fn from(f: &Bn254FqElement) -> Self {
Fq::from_be_bytes_mod_order(&f.0)
}
}
impl From<&Bn254FrElement> for Fr {
fn from(f: &Bn254FrElement) -> Self {
Fr::from_be_bytes_mod_order(&f.0)
}
}
/// Builds an affine G1 point from projective coordinates `[x, y, z]`.
///
/// Each coordinate is converted to [`Fq`] (reduced modulo the field order). The point is
/// returned only if it is on the curve and in the correct subgroup.
///
/// # Errors
///
/// Returns [`FastCryptoError::InvalidInput`] if `s` does not hold exactly three coordinates,
/// or if the point fails the curve or subgroup check.
pub(crate) fn g1_affine_from_str_projective(s: &CircomG1) -> Result<G1Affine, FastCryptoError> {
    let (x, y, z) = match s.as_slice() {
        [x, y, z] => (x, y, z),
        _ => return Err(FastCryptoError::InvalidInput),
    };
    // NOTE(review): with z = 0, the projective-to-affine conversion presumably yields the point
    // at infinity, which would pass both checks below. Confirm that accepting it is intended.
    let point: G1Affine = G1Projective::new_unchecked(x.into(), y.into(), z.into()).into();
    // The subgroup check assumes the point is on the curve, so it only runs after that passes.
    if point.is_on_curve() && point.is_in_correct_subgroup_assuming_on_curve() {
        Ok(point)
    } else {
        Err(FastCryptoError::InvalidInput)
    }
}
pub(crate) fn g2_affine_from_str_projective(s: &CircomG2) -> Result<G2Affine, FastCryptoError> {
if s.len() != 3 || s[0].len() != 2 || s[1].len() != 2 || s[2].len() != 2 {
return Err(FastCryptoError::InvalidInput);
}
let g2: G2Affine = G2Projective::new_unchecked(
Fq2::new((&s[0][0]).into(), (&s[0][1]).into()),
Fq2::new((&s[1][0]).into(), (&s[1][1]).into()),
Fq2::new((&s[2][0]).into(), (&s[2][1]).into()),
)
.into();
if !g2.is_on_curve() || !g2.is_in_correct_subgroup_assuming_on_curve() {
return Err(FastCryptoError::InvalidInput);
}
Ok(g2)
}
#[cfg(test)]
mod test {
    use crate::zk_login_utils::Bn254FqElement;
    use std::str::FromStr;

    use super::Bn254FrElement;
    use num_bigint::BigUint;
    use proptest::prelude::*;

    /// Strings that are not plain decimal numbers (underscores, leading zeros, letters) must be
    /// rejected by both element types.
    #[test]
    fn from_str_on_digits_only() {
        let rejected = ["10_________0", "000001", "garbage"];
        for input in rejected {
            assert!(Bn254FrElement::from_str(input).is_err());
            assert!(Bn254FqElement::from_str(input).is_err());
        }
    }

    /// `unpadded` strips leading zero bytes but keeps one byte for an all-zero value.
    #[test]
    fn unpadded_slice() {
        let all_zero = Bn254FrElement([0u8; 32]);
        assert_eq!(all_zero.unpadded(), [0u8].as_slice());

        let mut bytes = [1u8; 32];
        bytes[0] = 0;
        let one_leading_zero = Bn254FrElement(bytes);
        assert_eq!(one_leading_zero.unpadded(), [1u8; 31].as_slice());
    }

    proptest! {
        /// Parsing arbitrary 33..1024-byte integers must return rather than panic.
        #[test]
        fn dont_crash_on_large_inputs(
            bytes in proptest::collection::vec(any::<u8>(), 33..1024)
        ) {
            let decimal = BigUint::from_bytes_be(&bytes).to_str_radix(10);
            let _ = Bn254FrElement::from_str(&decimal);
            let _ = Bn254FqElement::from_str(&decimal);
        }
    }
}