use crate::errors::EncodingError;
use snarkvm_curves::traits::{
pairing_engine::{AffineCurve, ProjectiveCurve},
Group,
MontgomeryModelParameters,
TEModelParameters,
};
use snarkvm_fields::{Field, LegendreSymbol, One, SquareRootField, Zero};
use snarkvm_utilities::{to_bytes, FromBytes, ToBytes};
use std::{cmp, marker::PhantomData, ops::Neg};
/// Elligator 2 encoding between base-field elements of `P` and points of the group `G`.
///
/// The type carries no data: `P` supplies the Montgomery and twisted Edwards curve
/// coefficients, and `G` is the projective group whose affine points are produced by
/// `encode` and consumed by `decode`.
pub struct Elligator2<P, G>
where
    P: MontgomeryModelParameters + TEModelParameters,
    G: Group + ProjectiveCurve,
{
    /// Marker for the curve-parameter type; never holds a value.
    _parameters: PhantomData<P>,
    /// Marker for the group type; never holds a value.
    _group: PhantomData<G>,
}
impl<P: MontgomeryModelParameters + TEModelParameters, G: Group + ProjectiveCurve> Elligator2<P, G> {
const A: P::BaseField = <P as MontgomeryModelParameters>::COEFF_A;
const B: P::BaseField = <P as MontgomeryModelParameters>::COEFF_B;
const D: P::BaseField = <P as TEModelParameters>::COEFF_D;
#[allow(clippy::many_single_char_names)]
pub fn encode(input: &P::BaseField) -> Result<(<G as ProjectiveCurve>::Affine, bool), EncodingError> {
if input.is_zero() {
return Err(EncodingError::InputMustBeNonzero);
}
let sign_high = input > &input.neg();
let input = if sign_high { *input } else { input.neg() };
let (a, b) = {
let a = Self::A * &Self::B.inverse().unwrap();
let b = P::BaseField::one() * &Self::B.square().inverse().unwrap();
(a, b)
};
let (u, v) = {
let r = input;
let u = Self::D;
let ur2 = r.square() * &u;
{
#[cfg(debug_assertions)]
assert!(u.legendre().is_qnr());
assert_ne!(P::BaseField::one() + &ur2, P::BaseField::zero());
let a2 = a.square();
assert_ne!(a2 * &ur2, (P::BaseField::one() + &ur2).square() * &b);
}
let v = (P::BaseField::one() + &ur2).inverse().unwrap() * &(-a);
let v2 = v.square();
let v3 = v2 * &v;
let av2 = a * &v2;
let bv = b * &v;
let e = (v3 + &(av2 + &bv)).legendre();
let two = P::BaseField::one().double();
let x = match e {
LegendreSymbol::Zero => -(a * &two.inverse().unwrap()),
LegendreSymbol::QuadraticResidue => v,
LegendreSymbol::QuadraticNonResidue => (-v) - &a,
};
let x2 = x.square();
let x3 = x2 * &x;
let ax2 = a * &x2;
let bx = b * &x;
let value = (x3 + &(ax2 + &bx)).sqrt().unwrap();
let y = match e {
LegendreSymbol::Zero => P::BaseField::zero(),
LegendreSymbol::QuadraticResidue => -value,
LegendreSymbol::QuadraticNonResidue => value,
};
(x, y)
};
{
let v2 = v.square();
let u2 = u.square();
let u3 = u2 * &u;
assert_eq!(v2, u3 + &(a * &u2) + &(b * &u));
}
let (s, t) = {
let s = u * &Self::B;
let t = v * &Self::B;
#[cfg(debug_assertions)]
{
let t2 = t.square();
let s2 = s.square();
let s3 = s2 * &s;
assert_eq!(Self::B * &t2, s3 + &(Self::A * &s2) + &s);
}
(s, t)
};
let (x, y) = {
let x = s * &t.inverse().unwrap();
let numerator = s - &P::BaseField::one();
let denominator = s + &P::BaseField::one();
let y = numerator * &denominator.inverse().unwrap();
(x, y)
};
Ok((<G as ProjectiveCurve>::Affine::read(&to_bytes![x, y]?[..])?, sign_high))
}
#[allow(clippy::many_single_char_names)]
pub fn decode(
group_element: &<G as ProjectiveCurve>::Affine,
sign_high: bool,
) -> Result<P::BaseField, EncodingError> {
if group_element.is_zero() {
return Err(EncodingError::InputMustBeNonzero);
}
let x = P::BaseField::read(&to_bytes![group_element.to_x_coordinate()]?[..])?;
let y = P::BaseField::read(&to_bytes![group_element.to_y_coordinate()]?[..])?;
let (a, b) = {
let a = Self::A * &Self::B.inverse().unwrap();
let b = P::BaseField::one() * &Self::B.square().inverse().unwrap();
(a, b)
};
let (u_reconstructed, v_reconstructed) = {
let numerator = P::BaseField::one() + &y;
let denominator = P::BaseField::one() - &y;
let u = numerator * &(denominator.inverse().unwrap());
let v = numerator * &((denominator * &x).inverse().unwrap());
#[cfg(debug_assertions)]
{
let v2 = v.square();
let u2 = u.square();
let u3 = u2 * &u;
assert_eq!(Self::B * &v2, u3 + &(Self::A * &u2) + &u);
}
let u = u * &Self::B.inverse().unwrap();
let v = v * &Self::B.inverse().unwrap();
{
let v2 = v.square();
let u2 = u.square();
let u3 = u2 * &u;
assert_eq!(v2, u3 + &(a * &u2) + &(b * &u));
}
(u, v)
};
let x = u_reconstructed;
let u = Self::D;
{
#[cfg(debug_assertions)]
assert!(u.legendre().is_qnr());
assert_ne!(x, -a);
if y.is_zero() {
assert!(x.is_zero());
}
assert_eq!((-(u * &x) * &(x + &a)).legendre(), LegendreSymbol::QuadraticResidue);
}
let exists_in_sqrt_fq2 = v_reconstructed.square().sqrt().unwrap() == v_reconstructed;
let element = if exists_in_sqrt_fq2 {
let numerator = -x;
let denominator = (x + &a) * &u;
(numerator * &denominator.inverse().unwrap()).sqrt().unwrap()
} else {
let numerator = -x - &a;
let denominator = x * &u;
(numerator * &denominator.inverse().unwrap()).sqrt().unwrap()
};
let element = if sign_high {
cmp::max(element, -element)
} else {
cmp::min(element, -element)
};
#[cfg(debug_assertions)]
assert!(&Self::encode(&element)?.0 == group_element);
Ok(element)
}
}