use ff::{Field, PrimeField};
use group::cofactor::CofactorGroup;
use midnight_curves::{ff_ext::Legendre, JubjubExtended as Jubjub, JubjubSubgroup};
use subtle::{ConditionallySelectable, ConstantTimeEq};
use super::mtc_params::{MapToEdwardsParams, MapToWeierstrassParams};
use crate::ecc::curves::CircuitCurve;
/// Native map-to-curve for a circuit curve `C`.
///
/// NOTE(review): the `CPU` suffix suggests this is the off-circuit counterpart of an
/// in-circuit map-to-curve gadget — confirm against the gadget implementation.
pub trait MapToCurveCPU<C: CircuitCurve> {
/// Deterministically maps the base-field element `u` to an element of
/// `C::CryptographicGroup`.
fn map_to_curve(u: &C::Base) -> C::CryptographicGroup;
}
impl MapToCurveCPU<Jubjub> for Jubjub {
    /// Maps a base-field element `u` to a point of the Jubjub prime-order subgroup.
    ///
    /// The pipeline has five stages:
    /// 1. Shallue–van de Woestijne map onto the short-Weierstrass model of the curve.
    /// 2. Weierstrass → Montgomery change of coordinates.
    /// 3. Montgomery → twisted Edwards rational map.
    /// 4. Construction of the extended point via `Jubjub::from_xy`.
    /// 5. Cofactor clearing, which yields a `JubjubSubgroup` element.
    ///
    /// # Panics
    /// Panics if `Jubjub::from_xy` rejects the computed coordinates (presumably an
    /// on-curve check — confirm). This cannot be triggered by the choice of `u`.
    /// It indicates inconsistent map-to-curve parameters.
    fn map_to_curve(u: &<Jubjub as CircuitCurve>::Base) -> JubjubSubgroup {
        let (x, y) = svdw_map_to_curve::<Jubjub>(u);
        let (x, y) = weierstrass_to_montgomery::<Jubjub>(&x, &y);
        let (x, y) = montgomery_to_edwards::<Jubjub>(&x, &y);
        let extended_point = Jubjub::from_xy(x, y)
            .expect("map-to-curve pipeline must produce a point on the Jubjub curve");
        // The mapped point lives in the full curve group. Clearing the cofactor
        // moves it into the prime-order subgroup returned to callers.
        <Jubjub as CofactorGroup>::clear_cofactor(&extended_point)
    }
}
/// Shallue–van de Woestijne (SvdW) map from a field element `u` to affine
/// coordinates `(x, y)` on the short-Weierstrass curve `y^2 = x^3 + A*x + B`.
///
/// `A`, `B`, `Z` (`SVDW_Z`) and the constants `c1()`..`c4()` come from
/// `MapToWeierstrassParams`. The sequence of operations mirrors the straight-line
/// SvdW procedure of RFC 9380, Appendix F.1, step for step. Keep the statement order
/// intact so it stays auditable against the spec.
///
/// NOTE(review): `c1()`..`c4()` are presumably the RFC constants
/// `c1 = g(Z)`, `c2 = -Z/2`, `c3 = sqrt(-g(Z) * (3Z^2 + 4A))`,
/// `c4 = -4g(Z) / (3Z^2 + 4A)`, where `g(x) = x^3 + A*x + B`. They are defined in
/// `mtc_params` — confirm.
///
/// Candidate selection uses `conditional_select` with `Choice` values rather than
/// branches.
///
/// # Returns
/// `(x, y)` with `g(x) = y^2`. The parity of `y` (via `is_odd`, the RFC's `sgn0`
/// for prime fields) is forced to match the parity of `u`.
///
/// # Panics
/// Panics if `g(x)` has no square root (`sqrt().unwrap()`). For parameters meeting
/// the RFC's requirements on `Z`, one of the three candidates always yields a
/// square, so a panic indicates bad parameters.
fn svdw_map_to_curve<C>(u: &C::Base) -> (C::Base, C::Base)
where
C: CircuitCurve + MapToWeierstrassParams<C::Base>,
C::Base: Legendre,
{
// Steps 1-6: tv1 = c1*u^2, tv2 = 1 + tv1, tv1 = 1 - tv1, tv3 = inv0(tv1 * tv2).
let tv1 = u.square();
let tv1 = tv1 * C::c1();
let tv2 = C::Base::ONE + tv1;
let tv1 = C::Base::ONE - tv1;
let tv3 = tv1 * tv2;
// inv0: the inverse, or 0 when the input is 0.
let tv3 = tv3.invert().unwrap_or(C::Base::ZERO);
// Steps 7-10: tv4 = c3 * u * tv1 * tv3; x1 = c2 - tv4.
let tv4 = *u * tv1;
let tv4 = tv4 * tv3;
let tv4 = tv4 * C::c3();
let x1 = C::c2() - tv4;
// Steps 11-14: gx1 = g(x1) = (x1^2 + A) * x1 + B.
let gx1 = x1.square();
let gx1 = gx1 + C::A;
let gx1 = gx1 * x1;
let gx1 = gx1 + C::B;
// Step 15: e1 = is_square(gx1). Zero is not a non-residue, so it counts as a square.
let e1 = !gx1.ct_quadratic_non_residue();
// Steps 16-20: x2 = c2 + tv4; gx2 = g(x2).
let x2 = C::c2() + tv4;
let gx2 = x2.square();
let gx2 = gx2 + C::A;
let gx2 = gx2 * x2;
let gx2 = gx2 + C::B;
// Step 21: e2 = is_square(gx2) AND NOT e1.
let e2 = !gx2.ct_quadratic_non_residue() & (!e1);
// Steps 22-26: x3 = c4 * (tv2^2 * tv3)^2 + Z.
let x3 = tv2.square();
let x3 = x3 * tv3;
let x3 = x3.square();
let x3 = x3 * C::c4();
let x3 = x3 + C::SVDW_Z;
// Steps 27-28: x = x1 if e1, else x2 if e2, else x3.
let x = C::Base::conditional_select(&x3, &x1, e1);
let x = C::Base::conditional_select(&x, &x2, e2);
// Steps 29-33: y = sqrt(g(x)).
let gx = x.square();
let gx = gx + C::A;
let gx = gx * x;
let gx = gx + C::B;
let y = gx.sqrt().unwrap();
// Steps 34-35: choose the root whose parity (sgn0) matches that of u.
let e3 = u.is_odd().ct_eq(&y.is_odd());
let y = C::Base::conditional_select(&-y, &y, e3);
(x, y)
}
/// Converts short-Weierstrass affine coordinates `(x, y)` to Montgomery coordinates
/// `(s, t) = (K*x - J/3, K*y)`, with `K = MONT_K` and `J = MONT_J`.
///
/// NOTE(review): this is the standard change of variables for a Montgomery curve of
/// the form `K*t^2 = s^3 + J*s^2 + s`. It assumes the Weierstrass `A`, `B` used by
/// the SvdW step were derived from the same `J`, `K` — confirm in `mtc_params`.
///
/// # Panics
/// Panics only if 3 is not invertible in the base field (characteristic 3).
fn weierstrass_to_montgomery<C>(x: &C::Base, y: &C::Base) -> (C::Base, C::Base)
where
    C: CircuitCurve + MapToEdwardsParams<C::Base>,
{
    let inv_three = C::Base::from(3).invert().unwrap();
    // Horizontal shift J/3 that undoes the depression of the Montgomery cubic.
    let shift = C::MONT_J * inv_three;
    let s = C::MONT_K * *x - shift;
    let t = C::MONT_K * *y;
    (s, t)
}
/// Maps Montgomery coordinates `(s, t)` (passed as `x`, `y`) to twisted Edwards
/// coordinates `(v, w) = (s / t, (s - 1) / (s + 1))`.
///
/// This follows the rational map of RFC 9380, Appendix D.1. When `t = 0` or
/// `s = -1`, the denominator `(s + 1) * t` vanishes and the output is `(0, 1)`.
///
/// No curve constants are read here. NOTE(review): any Montgomery/Edwards scaling
/// is presumably folded into the `MapToEdwardsParams` constants used by
/// `weierstrass_to_montgomery` — confirm.
fn montgomery_to_edwards<C>(x: &C::Base, y: &C::Base) -> (C::Base, C::Base)
where
    C: CircuitCurve + MapToEdwardsParams<C::Base>,
{
    // tv1 = s + 1
    let tv1 = *x + C::Base::ONE;
    // tv2 = inv0((s + 1) * t): the inverse, or 0 when the product is 0.
    let tv2 = (tv1 * *y).invert().unwrap_or(C::Base::ZERO);
    // v = tv2 * (s + 1) * s = s / t (and 0 in the exceptional case).
    let x_prime = tv2 * tv1 * *x;
    // w = tv2 * t * (s - 1) = (s - 1) / (s + 1)
    let y_prime = tv2 * *y * (*x - C::Base::ONE);
    // Exceptional case: tv2 == 0 means w must be forced to 1. `is_zero` yields a
    // `Choice` directly, which keeps the selection constant-time and avoids a
    // variable-time `==` followed by a bool -> u8 -> Choice round-trip.
    let exceptional = tv2.is_zero();
    let y_prime = C::Base::conditional_select(&y_prime, &C::Base::ONE, exceptional);
    (x_prime, y_prime)
}