use std::{fmt::Debug, ops::Rem};
#[cfg(feature = "dev-curves")]
use midnight_curves::bn256;
use midnight_curves::{k256, p256};
use num_bigint::{BigInt, BigInt as BI, ToBigInt};
use num_traits::{One, Signed};
use crate::{ecc::curves::CircuitCurve, CircuitField};
/// Parameters describing how elements of the field `K` are represented as limbs
/// inside a circuit over the field `F`.
///
/// An element of `K` is split into `NB_LIMBS` limbs in base `2^LOG2_BASE`
/// (`check_params` in this module requires `(2^LOG2_BASE)^NB_LIMBS >= K::modulus()`).
/// NOTE(review): `F` is presumably the native circuit field and `K` the emulated
/// field (the curve blanket impl uses `F = C::ScalarField`, `K = C::Base`) — confirm.
pub trait FieldEmulationParams<F: CircuitField, K: CircuitField>:
    Default + Clone + Debug + PartialEq + Eq
{
    /// Base-2 logarithm of the limb base: limbs are digits in base `2^LOG2_BASE`.
    const LOG2_BASE: u32;

    /// Number of limbs used to represent one element of `K`.
    const NB_LIMBS: u32;

    /// Returns `2^(LOG2_BASE * i) mod m` for `i` in `0..NB_LIMBS`, where
    /// `m = K::modulus()`; entry `i` is the weight of limb `i` reduced modulo `m`.
    fn base_powers() -> Vec<BI> {
        let two = BI::from(2);
        let m = &K::modulus().to_bigint().unwrap();
        (0..Self::NB_LIMBS)
            .map(|i| two.pow(Self::LOG2_BASE * i).rem(m))
            .collect::<Vec<_>>()
    }

    /// Returns `2^(LOG2_BASE * (i + j)) mod m` for all `i, j` in `0..NB_LIMBS`,
    /// flattened row-major (entry `i * NB_LIMBS + j`), where `m = K::modulus()`.
    ///
    /// Entry `(i, j)` equals the product of the weights of limbs `i` and `j`,
    /// reduced modulo `m`.
    fn double_base_powers() -> Vec<BI> {
        let two = BI::from(2);
        let m = &K::modulus().to_bigint().unwrap();
        let nb_limbs = Self::NB_LIMBS as usize;
        // Only the exponents `0..=2 * (NB_LIMBS - 1)` occur, so each of the
        // `2 * NB_LIMBS - 1` distinct powers is computed once instead of
        // recomputing one big-integer power per (i, j) pair.
        let powers = (0..(2 * Self::NB_LIMBS).saturating_sub(1))
            .map(|k| two.pow(Self::LOG2_BASE * k).rem(m))
            .collect::<Vec<_>>();
        // Row `i` of the table is the window `powers[i..i + NB_LIMBS]`.
        (0..nb_limbs)
            .flat_map(|i| powers[i..i + nb_limbs].iter().cloned())
            .collect()
    }

    /// Auxiliary moduli for this field pair; there is no default, each
    /// implementation supplies its own list.
    /// NOTE(review): presumably consumed by the emulated-arithmetic identity
    /// checks — confirm against the gadget that reads them.
    fn moduli() -> Vec<BI>;

    /// Limb bound; defaults to `2^(2 * LOG2_BASE)`.
    /// NOTE(review): the name suggests an upper bound on limb values — confirm.
    fn max_limb_bound() -> BI {
        BI::from(2).pow(2 * Self::LOG2_BASE)
    }

    /// Range-check limb size parameter.
    /// NOTE(review): presumably the bit size of the chunks limbs are range-checked
    /// in — confirm against the range-check gadget.
    const RC_LIMB_SIZE: u32;
}
pub(crate) fn check_params<F, K, P>()
where
F: CircuitField,
K: CircuitField,
P: FieldEmulationParams<F, K>,
{
let m = &K::modulus().to_bigint().unwrap();
let base = BI::from(2).pow(P::LOG2_BASE);
let nb_limbs = P::NB_LIMBS;
assert!(*m > BI::one());
assert!(base > BI::one());
assert!(BI::pow(&base, nb_limbs) >= *m);
let base_powers = P::base_powers();
let double_base_powers = P::double_base_powers();
assert_eq!(base_powers.len(), nb_limbs as usize);
assert_eq!(double_base_powers.len(), (nb_limbs * nb_limbs) as usize);
let expected_powers = (0..nb_limbs).map(|i| BI::pow(&base, i).rem(m));
let expected_double_powers = (0..nb_limbs)
.flat_map(|i| (0..nb_limbs).map(|j| BI::pow(&base, i + j).rem(m)).collect::<Vec<_>>());
base_powers
.iter()
.chain(double_base_powers.iter())
.zip(expected_powers.chain(expected_double_powers))
.for_each(|(b, e)| {
assert!(!BI::is_negative(b));
assert_eq!(b.rem(m), e.rem(m))
});
}
/// Fieldless type on which `FieldEmulationParams<F, K>` is implemented once per
/// supported `(F, K)` field pair in this module.
#[derive(Clone, Default, Debug, PartialEq, Eq)]
pub struct MultiEmulationParams {}
/// Lets any curve type `C` serve as the emulation parameters for the pair
/// `(C::ScalarField, C::Base)`, provided `MultiEmulationParams` implements that
/// pair: every required item is forwarded to that implementation, and the
/// default methods (`base_powers`, `double_base_powers`, `max_limb_bound`) are
/// inherited from the trait.
impl<C: CircuitCurve + Default> FieldEmulationParams<C::ScalarField, C::Base> for C
where
    MultiEmulationParams: FieldEmulationParams<C::ScalarField, C::Base>,
{
    const LOG2_BASE: u32 =
        <MultiEmulationParams as FieldEmulationParams<C::ScalarField, C::Base>>::LOG2_BASE;
    const NB_LIMBS: u32 =
        <MultiEmulationParams as FieldEmulationParams<C::ScalarField, C::Base>>::NB_LIMBS;
    const RC_LIMB_SIZE: u32 =
        <MultiEmulationParams as FieldEmulationParams<C::ScalarField, C::Base>>::RC_LIMB_SIZE;

    fn moduli() -> Vec<BigInt> {
        <MultiEmulationParams as FieldEmulationParams<C::ScalarField, C::Base>>::moduli()
    }
}
/// Emulating `k256::Fp` over `bn256::Fr`: 4 limbs of 64 bits (256 bits of limb
/// capacity), one auxiliary modulus `2^128`.
#[cfg(feature = "dev-curves")]
impl FieldEmulationParams<bn256::Fr, k256::Fp> for MultiEmulationParams {
    const LOG2_BASE: u32 = 64;
    const NB_LIMBS: u32 = 4;
    const RC_LIMB_SIZE: u32 = 16;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(128);
        vec![aux]
    }
}
/// Emulating `k256::Fp` over `midnight_curves::Fq`: 4 limbs of 64 bits (256 bits
/// of limb capacity), one auxiliary modulus `2^128`.
impl FieldEmulationParams<midnight_curves::Fq, k256::Fp> for MultiEmulationParams {
    const LOG2_BASE: u32 = 64;
    const NB_LIMBS: u32 = 4;
    const RC_LIMB_SIZE: u32 = 16;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(128);
        vec![aux]
    }
}
/// Emulating `k256::Fq` over `bn256::Fr`: 5 limbs of 52 bits (260 bits of limb
/// capacity), one auxiliary modulus `2^141`.
#[cfg(feature = "dev-curves")]
impl FieldEmulationParams<bn256::Fr, k256::Fq> for MultiEmulationParams {
    const LOG2_BASE: u32 = 52;
    const NB_LIMBS: u32 = 5;
    const RC_LIMB_SIZE: u32 = 14;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(141);
        vec![aux]
    }
}
/// Emulating `k256::Fq` over `midnight_curves::Fq`: 4 limbs of 64 bits (256 bits
/// of limb capacity), auxiliary moduli `2^118` and `2^118 - 1`.
impl FieldEmulationParams<midnight_curves::Fq, k256::Fq> for MultiEmulationParams {
    const LOG2_BASE: u32 = 64;
    const NB_LIMBS: u32 = 4;
    const RC_LIMB_SIZE: u32 = 17;

    fn moduli() -> Vec<BigInt> {
        let pow2 = BigInt::from(2).pow(118);
        let pow2_minus_one = &pow2 - BigInt::one();
        vec![pow2, pow2_minus_one]
    }
}
/// Emulating `p256::Fp` over `midnight_curves::Fq`: 4 limbs of 64 bits (256 bits
/// of limb capacity), auxiliary moduli `2^122` and `2^122 - 507376`.
impl FieldEmulationParams<midnight_curves::Fq, p256::Fp> for MultiEmulationParams {
    const LOG2_BASE: u32 = 64;
    const NB_LIMBS: u32 = 4;
    const RC_LIMB_SIZE: u32 = 17;

    fn moduli() -> Vec<BigInt> {
        let pow2 = BigInt::from(2).pow(122);
        let shifted = &pow2 - BigInt::from(507376);
        vec![pow2, shifted]
    }
}
/// Emulating `p256::Fq` over `midnight_curves::Fq`: 4 limbs of 64 bits (256 bits
/// of limb capacity), auxiliary moduli `2^118` and `2^118 - 1`.
impl FieldEmulationParams<midnight_curves::Fq, p256::Fq> for MultiEmulationParams {
    const LOG2_BASE: u32 = 64;
    const NB_LIMBS: u32 = 4;
    const RC_LIMB_SIZE: u32 = 17;

    fn moduli() -> Vec<BigInt> {
        let pow2 = BigInt::from(2).pow(118);
        let pow2_minus_one = &pow2 - BigInt::one();
        vec![pow2, pow2_minus_one]
    }
}
/// Emulating `midnight_curves::Fp` over `midnight_curves::Fq`: 7 limbs of 56 bits
/// (392 bits of limb capacity), auxiliary moduli `2^134` and `2^134 - 1`.
impl FieldEmulationParams<midnight_curves::Fq, midnight_curves::Fp> for MultiEmulationParams {
    const LOG2_BASE: u32 = 56;
    const NB_LIMBS: u32 = 7;
    fn moduli() -> Vec<BigInt> {
        vec![
            BigInt::from(2).pow(134),
            // `BigInt::one()` for consistency with the other `2^k - 1` moduli.
            BigInt::from(2).pow(134) - BigInt::one(),
        ]
    }
    const RC_LIMB_SIZE: u32 = 15;
}
/// Emulating `bn256::Fq` over `midnight_curves::Fq`: 5 limbs of 52 bits (260 bits
/// of limb capacity), one auxiliary modulus `2^142`.
#[cfg(feature = "dev-curves")]
impl FieldEmulationParams<midnight_curves::Fq, bn256::Fq> for MultiEmulationParams {
    const LOG2_BASE: u32 = 52;
    const NB_LIMBS: u32 = 5;
    const RC_LIMB_SIZE: u32 = 14;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(142);
        vec![aux]
    }
}
/// Emulating `bn256::Fq` over `bn256::Fr`: 5 limbs of 52 bits (260 bits of limb
/// capacity), one auxiliary modulus `2^141`.
#[cfg(feature = "dev-curves")]
impl FieldEmulationParams<bn256::Fr, bn256::Fq> for MultiEmulationParams {
    const LOG2_BASE: u32 = 52;
    const NB_LIMBS: u32 = 5;
    const RC_LIMB_SIZE: u32 = 14;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(141);
        vec![aux]
    }
}
/// Emulating `midnight_curves::Fq` over itself: 5 limbs of 52 bits (260 bits of
/// limb capacity), one auxiliary modulus `2^142`.
impl FieldEmulationParams<midnight_curves::Fq, midnight_curves::Fq> for MultiEmulationParams {
    const LOG2_BASE: u32 = 52;
    const NB_LIMBS: u32 = 5;
    const RC_LIMB_SIZE: u32 = 14;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(142);
        vec![aux]
    }
}
/// Emulating `curve25519::Fp` over `midnight_curves::Fq`: 4 limbs of 64 bits
/// (256 bits of limb capacity), one auxiliary modulus `2^128`.
impl FieldEmulationParams<midnight_curves::Fq, midnight_curves::curve25519::Fp>
    for MultiEmulationParams
{
    const LOG2_BASE: u32 = 64;
    const NB_LIMBS: u32 = 4;
    const RC_LIMB_SIZE: u32 = 16;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(128);
        vec![aux]
    }
}
/// Emulating `curve25519::Scalar` over `midnight_curves::Fq`: 5 limbs of 51 bits
/// (255 bits of limb capacity), one auxiliary modulus `2^146`.
impl FieldEmulationParams<midnight_curves::Fq, midnight_curves::curve25519::Scalar>
    for MultiEmulationParams
{
    const LOG2_BASE: u32 = 51;
    const NB_LIMBS: u32 = 5;
    const RC_LIMB_SIZE: u32 = 17;

    fn moduli() -> Vec<BigInt> {
        let aux = BigInt::from(2).pow(146);
        vec![aux]
    }
}