use std::sync::OnceLock;
use crate::primes::{is_probable_prime, mod_inverse, random_below};
use crate::bigint::{BigInt, BigUint, Sign};
use crate::csprng::Csprng;
/// One signed term `coef · 2^offset` of the decomposition
/// δ = 2^k − p used by the fold-based reduction.
#[derive(Debug, Clone, Copy)]
struct ReductionTerm {
    /// Bit position of the term; table validation requires it to be below `k`.
    offset: usize,
    /// Signed multiplier of `2^offset`; table validation rejects zero.
    coef: i64,
}
/// Table entry describing a special-form prime `p = 2^k − δ` and the
/// signed terms that make up δ.
#[derive(Debug, Clone)]
struct ReductionParams {
    /// Split point: a value is decomposed as `high · 2^k + low`.
    k: usize,
    /// Signed terms whose sum is δ = 2^k − p.
    terms: &'static [ReductionTerm],
    /// The prime modulus itself.
    p: BigUint,
    /// Identifier used in panic/diagnostic messages.
    name: &'static str,
    /// Whether `detect_kind` should route this prime to the fold path
    /// (when `false`, the prime is handled as `FieldKind::Generic`).
    prefer_fast: bool,
}
/// Multiplication strategy selected for a `PrimeField` at construction time.
#[derive(Debug, Clone)]
enum FieldKind {
    /// Plain `BigUint::mod_mul`.
    Generic,
    /// Dedicated u128 path for p = 2^127 − 1.
    Mersenne127,
    /// Fold-based reduction driven by a table entry.
    Reduction(Box<ReductionParams>),
}
/// The prime field GF(p), operating on `BigUint` residues.
#[derive(Debug, Clone)]
pub struct PrimeField {
    /// Field modulus.
    p: BigUint,
    /// Multiplication strategy chosen from `p` by `detect_kind`.
    kind: FieldKind,
}
/// Chooses the multiplication strategy for modulus `p`.
///
/// 2^127 − 1 gets the dedicated u128 path. A prime found in the reduction
/// table gets the fold path only if its entry has `prefer_fast` set;
/// everything else (including table primes with `prefer_fast == false`)
/// falls back to `FieldKind::Generic`.
fn detect_kind(p: &BigUint) -> FieldKind {
    if p == cached_mersenne127() {
        return FieldKind::Mersenne127;
    }
    known_reductions()
        .iter()
        .find(|params| &params.p == p)
        .filter(|params| params.prefer_fast)
        .map_or(FieldKind::Generic, |params| {
            FieldKind::Reduction(Box::new(params.clone()))
        })
}
/// Returns a lazily built, process-wide copy of 2^127 − 1, so callers can
/// compare against or reduce by it without rebuilding the value.
fn cached_mersenne127() -> &'static BigUint {
    static MODULUS: OnceLock<BigUint> = OnceLock::new();
    MODULUS.get_or_init(mersenne127)
}
/// Returns the validated table of special-form primes, built on first use.
fn known_reductions() -> &'static [ReductionParams] {
    static KNOWN: OnceLock<Vec<ReductionParams>> = OnceLock::new();
    let table: &'static Vec<ReductionParams> = KNOWN.get_or_init(build_known_reductions);
    table
}
fn build_known_reductions() -> Vec<ReductionParams> {
static MERSENNE_TERMS: &[ReductionTerm] = &[ReductionTerm { offset: 0, coef: 1 }];
static CURVE25519_TERMS: &[ReductionTerm] = &[ReductionTerm { offset: 0, coef: 19 }];
static POLY1305_TERMS: &[ReductionTerm] = &[ReductionTerm { offset: 0, coef: 5 }];
static SECP256K1_TERMS: &[ReductionTerm] = &[
ReductionTerm { offset: 0, coef: 977 },
ReductionTerm { offset: 32, coef: 1 },
];
static CURVE448_TERMS: &[ReductionTerm] = &[
ReductionTerm { offset: 0, coef: 1 },
ReductionTerm { offset: 224, coef: 1 },
];
static P192_TERMS: &[ReductionTerm] = &[
ReductionTerm { offset: 0, coef: 1 },
ReductionTerm { offset: 64, coef: 1 },
];
static P224_TERMS: &[ReductionTerm] = &[
ReductionTerm { offset: 0, coef: -1 },
ReductionTerm { offset: 96, coef: 1 },
];
static P256_TERMS: &[ReductionTerm] = &[
ReductionTerm { offset: 0, coef: 1 },
ReductionTerm { offset: 96, coef: -1 },
ReductionTerm { offset: 192, coef: -1 },
ReductionTerm { offset: 224, coef: 1 },
];
static P384_TERMS: &[ReductionTerm] = &[
ReductionTerm { offset: 0, coef: 1 },
ReductionTerm { offset: 32, coef: -1 },
ReductionTerm { offset: 96, coef: 1 },
ReductionTerm { offset: 128, coef: 1 },
];
let table = vec![
ReductionParams { k: 521, terms: MERSENNE_TERMS, p: mersenne521(), name: "mersenne521", prefer_fast: true },
ReductionParams { k: 255, terms: CURVE25519_TERMS, p: curve25519_field(), name: "curve25519", prefer_fast: true },
ReductionParams { k: 130, terms: POLY1305_TERMS, p: poly1305_field(), name: "poly1305", prefer_fast: true },
ReductionParams { k: 256, terms: SECP256K1_TERMS, p: secp256k1_field(), name: "secp256k1", prefer_fast: true },
ReductionParams { k: 448, terms: CURVE448_TERMS, p: curve448_field(), name: "curve448", prefer_fast: true },
ReductionParams { k: 192, terms: P192_TERMS, p: nist_p192_field(), name: "nist_p192", prefer_fast: true },
ReductionParams { k: 224, terms: P224_TERMS, p: nist_p224_field(), name: "nist_p224", prefer_fast: true },
ReductionParams { k: 256, terms: P256_TERMS, p: nist_p256_field(), name: "nist_p256", prefer_fast: false },
ReductionParams { k: 384, terms: P384_TERMS, p: nist_p384_field(), name: "nist_p384", prefer_fast: true },
];
for params in &table {
validate_reduction_params(params);
}
table
}
/// Sanity-checks one entry of the special-form reduction table.
///
/// The checks, in the order they run:
/// 1. every term has a non-zero coefficient and an offset strictly below `k`;
/// 2. δ = Σ coefᵢ·2^offsetᵢ is strictly positive — `reduction_fold` relies
///    on this so every fold step stays non-negative;
/// 3. `p < 2^k`, so `2^k − p` can be formed as an unsigned value;
/// 4. δ equals `2^k − p`, i.e. the terms really describe this modulus.
///
/// # Panics
/// With a message naming `params.name` on the first failed check.
fn validate_reduction_params(params: &ReductionParams) {
    for term in params.terms {
        assert!(
            term.coef != 0,
            "{}: zero coefficient in reduction polynomial",
            params.name,
        );
        assert!(
            term.offset < params.k,
            "{}: term offset {} ≥ k = {}",
            params.name,
            term.offset,
            params.k,
        );
    }

    // Rebuild δ as a signed integer from its terms.
    let mut delta = BigInt::zero();
    for term in params.terms {
        let mut magnitude = BigUint::one();
        if term.offset > 0 {
            magnitude.shl_bits(term.offset);
        }
        let coef_abs = term.coef.unsigned_abs();
        if coef_abs != 1 {
            magnitude = magnitude.mul_ref(&BigUint::from_u64(coef_abs));
        }
        let signed = if term.coef > 0 {
            BigInt::from_biguint(magnitude)
        } else {
            BigInt::from_parts(Sign::Negative, magnitude)
        };
        delta = delta.add_ref(&signed);
    }
    assert!(
        delta.sign() == Sign::Positive,
        "{}: δ must be positive (got sign {:?}) — fold algorithm assumes positive products stay non-negative",
        params.name,
        delta.sign(),
    );

    let mut two_k = BigUint::one();
    two_k.shl_bits(params.k);
    // `sub_ref` works on unsigned values and cannot represent 2^k − p for
    // p ≥ 2^k. Reject such entries with a clear message, rather than relying
    // on the subtraction's underflow behaviour.
    assert!(
        params.p < two_k,
        "{}: modulus must be < 2^k (k = {}) for δ = 2^k − p to be positive",
        params.name,
        params.k,
    );
    let expected_delta = two_k.sub_ref(&params.p);
    // δ's sign was already checked above; only the magnitude remains.
    assert!(
        delta.magnitude() == &expected_delta,
        "{}: δ decomposition does not match 2^k − p",
        params.name,
    );
}
impl PrimeField {
    /// Creates the prime field GF(`p`), checking primality first.
    ///
    /// # Panics
    /// If `p <= 1`, or if `is_probable_prime(&p)` reports `p` as not prime.
    #[must_use]
    pub fn new(p: BigUint) -> Self {
        assert!(p > BigUint::one(), "modulus must be > 1");
        assert!(
            is_probable_prime(&p),
            "modulus must be prime (Miller–Rabin test)",
        );
        Self::new_unchecked(p)
    }

    /// Creates GF(`p`) without a primality check; the caller vouches that
    /// `p` is prime. Results of `inv` and the field laws are only meaningful
    /// for prime `p`.
    ///
    /// # Panics
    /// If `p <= 1`.
    #[must_use]
    pub fn new_unchecked(p: BigUint) -> Self {
        assert!(p > BigUint::one(), "modulus must be > 1");
        Self {
            kind: detect_kind(&p),
            p,
        }
    }

    /// Returns the field modulus `p`.
    #[must_use]
    pub fn modulus(&self) -> &BigUint {
        &self.p
    }

    /// Returns `a mod p`.
    #[must_use]
    pub fn reduce(&self, a: &BigUint) -> BigUint {
        a.modulo(&self.p)
    }

    /// Returns `(a + b) mod p`; inputs need not be reduced.
    #[must_use]
    pub fn add(&self, a: &BigUint, b: &BigUint) -> BigUint {
        self.reduce(&a.add_ref(b))
    }

    /// Returns `(a − b) mod p` in `[0, p)`; inputs need not be reduced.
    #[must_use]
    pub fn sub(&self, a: &BigUint, b: &BigUint) -> BigUint {
        let (lhs, rhs) = (self.reduce(a), self.reduce(b));
        if lhs < rhs {
            // Both sides are in [0, p): adding p once keeps the unsigned
            // difference non-negative.
            lhs.add_ref(&self.p).sub_ref(&rhs)
        } else {
            lhs.sub_ref(&rhs)
        }
    }

    /// Returns `−a mod p` in `[0, p)`.
    #[must_use]
    pub fn neg(&self, a: &BigUint) -> BigUint {
        let r = self.reduce(a);
        if r.is_zero() {
            return BigUint::zero();
        }
        self.p.sub_ref(&r)
    }

    /// Returns `(a · b) mod p`, dispatching on the strategy chosen at
    /// construction (generic, Mersenne-127, or table-driven fold).
    #[must_use]
    pub fn mul(&self, a: &BigUint, b: &BigUint) -> BigUint {
        match &self.kind {
            FieldKind::Generic => BigUint::mod_mul(a, b, &self.p),
            FieldKind::Mersenne127 => mersenne127_mul(a, b),
            FieldKind::Reduction(params) => reduction_mul(a, b, params),
        }
    }

    /// Returns the multiplicative inverse of `a` mod p, or `None` when
    /// `a ≡ 0` (or when `mod_inverse` finds no inverse).
    #[must_use]
    pub fn inv(&self, a: &BigUint) -> Option<BigUint> {
        let r = self.reduce(a);
        if r.is_zero() {
            None
        } else {
            mod_inverse(&r, &self.p)
        }
    }

    /// Samples a value below `p` using `random_below` with the supplied CSPRNG.
    #[must_use]
    pub fn random<R: Csprng>(&self, rng: &mut R) -> BigUint {
        let sample = random_below(rng, &self.p);
        sample.expect("modulus > 0")
    }
}
fn mersenne127_mul(a: &BigUint, b: &BigUint) -> BigUint {
let p = mersenne127();
let a128 = if a.bits() <= 127 {
a.low_u128()
} else {
a.modulo(&p).low_u128()
};
let b128 = if b.bits() <= 127 {
b.low_u128()
} else {
b.modulo(&p).low_u128()
};
BigUint::from_u128(mul_mod_mersenne127(a128, b128))
}
/// Computes `a · b mod (2^127 − 1)` for operands below 2^127.
///
/// The ≤254-bit product is assembled as `prod_hi · 2^128 + prod_lo` from
/// four 64×64-bit partial products. It is then reduced with 2^127 ≡ 1:
/// the bits above position 127 are added to the low 127 bits, the sum is
/// folded once more, and a final conditional subtraction gives the
/// canonical residue.
#[inline]
fn mul_mod_mersenne127(a: u128, b: u128) -> u128 {
    const P: u128 = (1u128 << 127) - 1;
    const HALF: u128 = u64::MAX as u128;

    let (a_lo, a_hi) = (a & HALF, a >> 64);
    let (b_lo, b_hi) = (b & HALF, b >> 64);

    let lo_lo = a_lo * b_lo;
    let lo_hi = a_lo * b_hi;
    let hi_lo = a_hi * b_lo;
    let hi_hi = a_hi * b_hi;

    // Bits 64..128 of the product, plus the carry into bit 128 and above.
    let middle = (lo_lo >> 64) + (lo_hi & HALF) + (hi_lo & HALF);
    let prod_lo = (lo_lo & HALF) | (middle << 64);
    let prod_hi = (middle >> 64) + (lo_hi >> 64) + (hi_lo >> 64) + hi_hi;

    // Split at bit 127: x = above · 2^127 + below ≡ above + below (mod P).
    let below = prod_lo & P;
    let above = (prod_lo >> 127) | (prod_hi << 1);
    let sum = below + above;
    let folded = (sum & P) + (sum >> 127);
    if folded < P {
        folded
    } else {
        folded - P
    }
}
/// Computes `a · b mod params.p` by repeated folding at bit `params.k`.
///
/// Operands wider than `k` bits are reduced first. The product is then
/// folded with `reduction_fold` until it fits in `k` bits, and a final
/// `modulo_positive` yields the canonical residue.
///
/// # Panics
/// If folding has not finished after 32 steps.
fn reduction_mul(a: &BigUint, b: &BigUint, params: &ReductionParams) -> BigUint {
    const MAX_FOLDS: usize = 32;

    // `Some(reduced)` only for operands that exceed k bits.
    let narrow = |x: &BigUint| (x.bits() > params.k).then(|| x.modulo(&params.p));
    let a_reduced = narrow(a);
    let b_reduced = narrow(b);
    let lhs = a_reduced.as_ref().unwrap_or(a);
    let rhs = b_reduced.as_ref().unwrap_or(b);

    let mut acc = BigInt::from_biguint(lhs.mul_ref(rhs));
    let mut performed = 0usize;
    while needs_fold(&acc, params.k) {
        assert!(
            performed < MAX_FOLDS,
            "reduction_mul did not converge for {} after {} folds",
            params.name,
            MAX_FOLDS,
        );
        acc = reduction_fold(&acc, params);
        performed += 1;
    }
    acc.modulo_positive(&params.p)
}
/// Reports whether `t` still needs a fold step: positive values wider than
/// `k` bits do, and zero does not. Negative values are reported as needing
/// a fold, which `reduction_fold` treats as an invariant violation.
fn needs_fold(t: &BigInt, k: usize) -> bool {
    match t.sign() {
        Sign::Positive => t.magnitude().bits() > k,
        Sign::Negative => true,
        Sign::Zero => false,
    }
}
/// Performs one fold step.
///
/// It writes `t = high · 2^k + low` and returns `low + high · δ`, which is
/// congruent to `t` modulo p because δ = 2^k − p (checked by
/// `validate_reduction_params`). Positive and negative δ terms are
/// accumulated separately so all arithmetic stays unsigned.
///
/// # Panics
/// If `t` is negative (an invariant violation).
fn reduction_fold(t: &BigInt, params: &ReductionParams) -> BigInt {
    match t.sign() {
        Sign::Zero => return BigInt::zero(),
        Sign::Negative => unreachable!(
            "reduction_fold on Negative t — δ > 0 invariant violated; check validate_reduction_params",
        ),
        Sign::Positive => {}
    }

    let magnitude = t.magnitude();
    let high = magnitude.shr_bits(params.k);
    let low = magnitude.low_bits(params.k);
    if high.is_zero() {
        return BigInt::from_biguint(low);
    }

    let mut plus = low;
    let mut minus = BigUint::zero();
    for term in params.terms {
        let mut contribution = high.clone();
        if term.offset > 0 {
            contribution.shl_bits(term.offset);
        }
        let scale = term.coef.unsigned_abs();
        if scale != 1 {
            contribution = contribution.mul_ref(&BigUint::from_u64(scale));
        }
        let bucket = if term.coef > 0 { &mut plus } else { &mut minus };
        *bucket = bucket.add_ref(&contribution);
    }

    debug_assert!(
        plus >= minus,
        "δ > 0 invariant violated at runtime: pos < neg in fold step",
    );
    BigInt::from_biguint(plus.sub_ref(&minus))
}
/// Returns the Mersenne prime 2^127 − 1.
#[must_use]
pub fn mersenne127() -> BigUint {
    let one = BigUint::one();
    let mut power = one.clone();
    power.shl_bits(127);
    power.sub_ref(&one)
}
/// Returns the Mersenne prime 2^521 − 1.
#[must_use]
pub fn mersenne521() -> BigUint {
    let one = BigUint::one();
    let mut power = one.clone();
    power.shl_bits(521);
    power.sub_ref(&one)
}
/// Returns the Curve25519 base-field prime 2^255 − 19.
#[must_use]
pub fn curve25519_field() -> BigUint {
    let delta = BigUint::from_u64(19);
    let mut power = BigUint::one();
    power.shl_bits(255);
    power.sub_ref(&delta)
}
/// Returns the Poly1305 prime 2^130 − 5.
#[must_use]
pub fn poly1305_field() -> BigUint {
    let delta = BigUint::from_u64(5);
    let mut power = BigUint::one();
    power.shl_bits(130);
    power.sub_ref(&delta)
}
/// Returns the secp256k1 base-field prime 2^256 − 2^32 − 977.
#[must_use]
pub fn secp256k1_field() -> BigUint {
    // δ = 2^32 + 977 = 4_294_968_273 fits in a u64.
    let delta = BigUint::from_u64((1u64 << 32) + 977);
    let mut power = BigUint::one();
    power.shl_bits(256);
    power.sub_ref(&delta)
}
/// Returns the Curve448 ("Goldilocks") prime 2^448 − 2^224 − 1.
#[must_use]
pub fn curve448_field() -> BigUint {
    let pow2 = |e: usize| {
        let mut v = BigUint::one();
        v.shl_bits(e);
        v
    };
    // 2^448 − (2^224 + 1)
    let negative = pow2(224).add_ref(&BigUint::one());
    pow2(448).sub_ref(&negative)
}
/// Returns the NIST P-192 prime 2^192 − 2^64 − 1.
#[must_use]
pub fn nist_p192_field() -> BigUint {
    let pow2 = |e: usize| {
        let mut v = BigUint::one();
        v.shl_bits(e);
        v
    };
    // 2^192 − (2^64 + 1)
    let negative = pow2(64).add_ref(&BigUint::one());
    pow2(192).sub_ref(&negative)
}
/// Returns the NIST P-224 prime 2^224 − 2^96 + 1.
#[must_use]
pub fn nist_p224_field() -> BigUint {
    let pow2 = |e: usize| {
        let mut v = BigUint::one();
        v.shl_bits(e);
        v
    };
    // (2^224 + 1) − 2^96
    let positive = pow2(224).add_ref(&BigUint::one());
    positive.sub_ref(&pow2(96))
}
/// Returns the NIST P-256 prime 2^256 − 2^224 + 2^192 + 2^96 − 1.
#[must_use]
pub fn nist_p256_field() -> BigUint {
    let pow2 = |e: usize| {
        let mut v = BigUint::one();
        v.shl_bits(e);
        v
    };
    // (2^256 + 2^192 + 2^96) − (2^224 + 1)
    let positive = pow2(256).add_ref(&pow2(192)).add_ref(&pow2(96));
    let negative = pow2(224).add_ref(&BigUint::one());
    positive.sub_ref(&negative)
}
/// Returns the NIST P-384 prime 2^384 − 2^128 − 2^96 + 2^32 − 1.
#[must_use]
pub fn nist_p384_field() -> BigUint {
    let pow2 = |e: usize| {
        let mut v = BigUint::one();
        v.shl_bits(e);
        v
    };
    // (2^384 + 2^32) − (2^128 + 2^96 + 1)
    let positive = pow2(384).add_ref(&pow2(32));
    let negative = pow2(128).add_ref(&pow2(96)).add_ref(&BigUint::one());
    positive.sub_ref(&negative)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csprng::ChaCha20Rng;

    /// GF(257): a small prime field for the basic arithmetic tests.
    fn small() -> PrimeField {
        PrimeField::new(BigUint::from_u64(257))
    }

    /// Field for `p` that always uses the generic `mod_mul` path; serves as
    /// the reference implementation in differential tests.
    fn generic_for(p: &BigUint) -> PrimeField {
        PrimeField {
            p: p.clone(),
            kind: FieldKind::Generic,
        }
    }

    /// Field for `p` that always uses the fold path, even when the table
    /// entry has `prefer_fast == false`. Panics if `p` is not in the table.
    fn force_reduction_for(p: &BigUint) -> PrimeField {
        for params in super::known_reductions() {
            if &params.p == p {
                return PrimeField {
                    p: p.clone(),
                    kind: super::FieldKind::Reduction(Box::new(params.clone())),
                };
            }
        }
        unreachable!("force_reduction_for: prime not in table");
    }

    #[test]
    fn add_sub_round_trip() {
        let f = small();
        let a = BigUint::from_u64(123);
        let b = BigUint::from_u64(200);
        let s = f.add(&a, &b);
        assert_eq!(f.sub(&s, &b), a);
        assert_eq!(f.sub(&s, &a), b);
    }

    #[test]
    fn sub_underflow_wraps() {
        let f = small();
        let a = BigUint::from_u64(5);
        let b = BigUint::from_u64(10);
        assert_eq!(f.sub(&a, &b), BigUint::from_u64(252));
    }

    #[test]
    fn neg_round_trip() {
        let f = small();
        for i in 0u64..20 {
            let a = BigUint::from_u64(i);
            assert_eq!(f.add(&a, &f.neg(&a)), BigUint::zero());
        }
    }

    #[test]
    fn inv_round_trip() {
        let f = small();
        for i in 1u64..20 {
            let a = BigUint::from_u64(i);
            let inv = f.inv(&a).expect("nonzero invertible mod prime");
            assert_eq!(f.mul(&a, &inv), BigUint::one());
        }
        assert!(f.inv(&BigUint::zero()).is_none());
    }

    #[test]
    #[should_panic(expected = "modulus must be prime")]
    fn new_rejects_composite_modulus() {
        let _ = PrimeField::new(BigUint::from_u64(255));
    }

    #[test]
    fn new_unchecked_skips_primality_check() {
        let f = PrimeField::new_unchecked(BigUint::from_u64(255));
        assert_eq!(f.modulus(), &BigUint::from_u64(255));
    }

    #[test]
    fn mersenne127_value() {
        let p = mersenne127();
        assert_eq!(p.bits(), 127);
        let next = p.add_ref(&BigUint::one());
        let mut two_pow_127 = BigUint::one();
        two_pow_127.shl_bits(127);
        assert_eq!(next, two_pow_127);
    }

    #[test]
    fn mersenne127_routes_to_dedicated_path() {
        let f = PrimeField::new_unchecked(mersenne127());
        assert!(matches!(f.kind, FieldKind::Mersenne127));
    }

    #[test]
    fn mersenne127_mul_matches_generic_on_random_fuzz() {
        let p = mersenne127();
        let fast = PrimeField::new_unchecked(p.clone());
        let generic = generic_for(&p);
        let edges = [
            BigUint::zero(),
            BigUint::one(),
            BigUint::from_u64(2),
            p.sub_ref(&BigUint::one()),
            p.clone(),
        ];
        for a in &edges {
            for b in &edges {
                assert_eq!(fast.mul(a, b), generic.mul(a, b), "mismatch on edge case");
            }
        }
        let mut r = ChaCha20Rng::from_seed(&[0x4Au8; 32]);
        for _ in 0..256 {
            let a = fast.random(&mut r);
            let b = fast.random(&mut r);
            assert_eq!(fast.mul(&a, &b), generic.mul(&a, &b));
        }
    }

    #[test]
    fn mersenne127_mul_handles_unreduced_inputs() {
        let p = mersenne127();
        let fast = PrimeField::new_unchecked(p.clone());
        let generic = generic_for(&p);
        let a = p.add_ref(&BigUint::from_u64(5));
        let b = p.add_ref(&p).add_ref(&BigUint::from_u64(3));
        assert_eq!(fast.mul(&a, &b), generic.mul(&a, &b));
    }

    /// Compares `fast` against `generic` on boundary operands (0, 1, 2,
    /// 2^(bits−1), p − 1, p, p + 1) and checks results are fully reduced.
    fn check_edges(p: &BigUint, fast: &PrimeField, generic: &PrimeField) {
        let zero = BigUint::zero();
        let one = BigUint::one();
        let p_minus_1 = p.sub_ref(&one);
        let p_plus_1 = p.add_ref(&one);
        let mut two_pow_k_minus_1 = BigUint::one();
        two_pow_k_minus_1.shl_bits(p.bits() - 1);
        let edges: Vec<BigUint> = vec![
            zero.clone(),
            one.clone(),
            BigUint::from_u64(2),
            two_pow_k_minus_1,
            p_minus_1,
            p.clone(),
            p_plus_1,
        ];
        for a in &edges {
            for b in &edges {
                let want = generic.mul(a, b);
                let got = fast.mul(a, b);
                assert_eq!(got, want, "edge mismatch: a.bits()={}, b.bits()={}", a.bits(), b.bits());
                assert!(got < *p, "result not reduced: bits={}", got.bits());
            }
        }
        assert_eq!(fast.mul(&BigUint::from_u64(0xC0FFEE), &zero), zero);
        let beef = BigUint::from_u64(0xBEEF);
        assert_eq!(fast.mul(&beef, &one), beef);
        assert_eq!(fast.mul(&beef, p), zero);
    }

    /// Differential fuzz of `fast` vs `generic` on random field elements,
    /// also checking reduction and commutativity.
    fn fuzz_against_generic(p: &BigUint, fast: &PrimeField, generic: &PrimeField, seed_byte: u8) {
        let mut r = ChaCha20Rng::from_seed(&[seed_byte; 32]);
        const N: usize = 16_384;
        for i in 0..N {
            let a = fast.random(&mut r);
            let b = fast.random(&mut r);
            let got = fast.mul(&a, &b);
            let want = generic.mul(&a, &b);
            assert_eq!(got, want, "fuzz iter {i}: fast {got:?} != generic {want:?}");
            assert!(got < *p, "result not reduced at iter {i}");
            assert_eq!(fast.mul(&b, &a), got, "non-commutative at iter {i}");
        }
    }

    /// Checks operands above p, including one far wider than k bits.
    fn check_unreduced_inputs(p: &BigUint, fast: &PrimeField, generic: &PrimeField) {
        let a = p.add_ref(&BigUint::from_u64(5));
        let b = p.add_ref(p).add_ref(&BigUint::from_u64(3));
        assert_eq!(fast.mul(&a, &b), generic.mul(&a, &b));
        let mut big = p.clone();
        big.shl_bits(50);
        let huge = big.add_ref(&BigUint::from_u64(1234));
        assert_eq!(fast.mul(&huge, &BigUint::from_u64(7)), generic.mul(&huge, &BigUint::from_u64(7)));
    }

    /// Checks the largest in-range product, (p − 1)², plus products involving p.
    fn check_worst_case_convergence(p: &BigUint, fast: &PrimeField, generic: &PrimeField) {
        let p_minus_1 = p.sub_ref(&BigUint::one());
        let want = generic.mul(&p_minus_1, &p_minus_1);
        let got = fast.mul(&p_minus_1, &p_minus_1);
        assert_eq!(got, want, "(p − 1)² mismatch");
        assert_eq!(fast.mul(p, p), BigUint::zero());
        assert_eq!(fast.mul(p, &BigUint::one()), BigUint::zero());
    }

    /// Runs every differential check for `p`. The fast path is forced even
    /// when `detect_kind` would pick Generic.
    fn full_prime_check(name: &'static str, p: BigUint, expected_bits: usize, seed: u8) {
        assert_eq!(p.bits(), expected_bits, "{name}: bit length");
        let kind = super::detect_kind(&p);
        let fast = if matches!(kind, super::FieldKind::Mersenne127 | super::FieldKind::Reduction(_)) {
            PrimeField::new_unchecked(p.clone())
        } else {
            force_reduction_for(&p)
        };
        let generic = generic_for(&p);
        check_edges(&p, &fast, &generic);
        check_unreduced_inputs(&p, &fast, &generic);
        check_worst_case_convergence(&p, &fast, &generic);
        fuzz_against_generic(&p, &fast, &generic, seed);
    }

    #[test]
    fn fuzz_mersenne521() {
        full_prime_check("mersenne521", mersenne521(), 521, 0x21);
    }

    #[test]
    fn fuzz_curve25519() {
        full_prime_check("curve25519", curve25519_field(), 255, 0x25);
    }

    #[test]
    fn fuzz_poly1305() {
        full_prime_check("poly1305", poly1305_field(), 130, 0x05);
    }

    #[test]
    fn fuzz_secp256k1() {
        full_prime_check("secp256k1", secp256k1_field(), 256, 0x6B);
    }

    #[test]
    fn fuzz_curve448() {
        full_prime_check("curve448", curve448_field(), 448, 0x48);
    }

    #[test]
    fn fuzz_nist_p192() {
        full_prime_check("nist_p192", nist_p192_field(), 192, 0x92);
    }

    #[test]
    fn fuzz_nist_p224() {
        full_prime_check("nist_p224", nist_p224_field(), 224, 0x24);
    }

    #[test]
    fn fuzz_nist_p256() {
        full_prime_check("nist_p256", nist_p256_field(), 256, 0x56);
    }

    #[test]
    fn fuzz_nist_p384() {
        full_prime_check("nist_p384", nist_p384_field(), 384, 0x84);
    }

    #[test]
    fn fuzz_mersenne127() {
        full_prime_check("mersenne127", mersenne127(), 127, 0xC1);
    }

    #[test]
    fn canonical_hex_values_match_standards() {
        let p256 = nist_p256_field();
        let expected_p256 = BigUint::from_be_bytes(&hex_decode(
            "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
        ));
        assert_eq!(p256, expected_p256, "NIST P-256");
        let p384 = nist_p384_field();
        let expected_p384 = BigUint::from_be_bytes(&hex_decode(
            "fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe\
             ffffffff0000000000000000ffffffff",
        ));
        assert_eq!(p384, expected_p384, "NIST P-384");
        let p224 = nist_p224_field();
        let expected_p224 = BigUint::from_be_bytes(&hex_decode(
            "ffffffffffffffffffffffffffffffff000000000000000000000001",
        ));
        assert_eq!(p224, expected_p224, "NIST P-224");
        let p192 = nist_p192_field();
        let expected_p192 = BigUint::from_be_bytes(&hex_decode(
            "fffffffffffffffffffffffffffffffeffffffffffffffff",
        ));
        assert_eq!(p192, expected_p192, "NIST P-192");
        let secp = secp256k1_field();
        let expected_secp = BigUint::from_be_bytes(&hex_decode(
            "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f",
        ));
        assert_eq!(secp, expected_secp, "secp256k1");
        let c25519 = curve25519_field();
        let expected_c25519 = BigUint::from_be_bytes(&hex_decode(
            "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
        ));
        assert_eq!(c25519, expected_c25519, "curve25519");
        let c448 = curve448_field();
        let expected_c448 = BigUint::from_be_bytes(&hex_decode(
            "fffffffffffffffffffffffffffffffffffffffffffffffffffffffe\
             ffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
        ));
        assert_eq!(c448, expected_c448, "curve448");
        let poly = poly1305_field();
        let expected_poly = BigUint::from_be_bytes(&hex_decode(
            "03fffffffffffffffffffffffffffffffb",
        ));
        assert_eq!(poly, expected_poly, "poly1305");
        let m521 = mersenne521();
        let mut expected_m521 = BigUint::one();
        expected_m521.shl_bits(521);
        expected_m521 = expected_m521.sub_ref(&BigUint::one());
        assert_eq!(m521, expected_m521, "mersenne521");
    }

    /// Decodes an even-length hex string (ASCII whitespace ignored) into bytes.
    fn hex_decode(s: &str) -> Vec<u8> {
        let cleaned: Vec<u8> = s.bytes().filter(|b| !b.is_ascii_whitespace()).collect();
        assert!(cleaned.len().is_multiple_of(2), "hex must have even length");
        cleaned
            .chunks_exact(2)
            .map(|pair| {
                let hi = nibble(pair[0]);
                let lo = nibble(pair[1]);
                (hi << 4) | lo
            })
            .collect()
    }

    /// Converts one ASCII hex digit to its value; panics on anything else.
    fn nibble(b: u8) -> u8 {
        match b {
            b'0'..=b'9' => b - b'0',
            b'a'..=b'f' => b - b'a' + 10,
            b'A'..=b'F' => b - b'A' + 10,
            _ => panic!("non-hex char"),
        }
    }

    #[test]
    fn detect_kind_routes_each_known_prime_to_fast_path() {
        let cases: &[(&str, BigUint)] = &[
            ("mersenne127", mersenne127()),
            ("mersenne521", mersenne521()),
            ("curve25519", curve25519_field()),
            ("poly1305", poly1305_field()),
            ("secp256k1", secp256k1_field()),
            ("curve448", curve448_field()),
            ("nist_p192", nist_p192_field()),
            ("nist_p224", nist_p224_field()),
            ("nist_p384", nist_p384_field()),
        ];
        for (name, p) in cases {
            let f = PrimeField::new_unchecked(p.clone());
            let routed = matches!(
                f.kind,
                FieldKind::Mersenne127 | FieldKind::Reduction(_)
            );
            assert!(routed, "{name} fell through to Generic");
        }
    }

    #[test]
    fn nist_p256_routes_to_generic_when_prefer_fast_false() {
        let f = PrimeField::new_unchecked(nist_p256_field());
        assert!(
            matches!(f.kind, FieldKind::Generic),
            "nist_p256 should route to Generic when prefer_fast is false"
        );
    }

    #[test]
    #[should_panic(expected = "zero coefficient")]
    fn validate_rejects_zero_coefficient() {
        static BAD: &[super::ReductionTerm] = &[super::ReductionTerm { offset: 0, coef: 0 }];
        let params = super::ReductionParams {
            k: 127,
            terms: BAD,
            p: mersenne127(),
            name: "bad_zero_coef",
            prefer_fast: true,
        };
        super::validate_reduction_params(&params);
    }

    #[test]
    #[should_panic(expected = "≥ k")]
    fn validate_rejects_offset_at_or_above_k() {
        static BAD: &[super::ReductionTerm] = &[super::ReductionTerm { offset: 127, coef: 1 }];
        let params = super::ReductionParams {
            k: 127,
            terms: BAD,
            p: mersenne127(),
            name: "bad_offset",
            prefer_fast: true,
        };
        super::validate_reduction_params(&params);
    }

    #[test]
    #[should_panic(expected = "δ must be positive")]
    fn validate_rejects_nonpositive_delta() {
        static BAD: &[super::ReductionTerm] = &[super::ReductionTerm { offset: 0, coef: -1 }];
        let mut p = BigUint::one();
        p.shl_bits(127);
        p = p.add_ref(&BigUint::one());
        let params = super::ReductionParams {
            k: 127,
            terms: BAD,
            p,
            name: "bad_negative_delta",
            prefer_fast: true,
        };
        super::validate_reduction_params(&params);
    }

    #[test]
    #[should_panic(expected = "does not match 2^k − p")]
    fn validate_rejects_mismatched_polynomial() {
        static BAD: &[super::ReductionTerm] = &[super::ReductionTerm { offset: 0, coef: 5 }];
        let params = super::ReductionParams {
            k: 127,
            terms: BAD,
            p: mersenne127(),
            name: "bad_polynomial_mismatch",
            prefer_fast: true,
        };
        super::validate_reduction_params(&params);
    }

    #[test]
    fn unknown_modulus_falls_through_to_generic() {
        let p = BigUint::from_u64(1_000_000_007);
        let f = PrimeField::new_unchecked(p);
        assert!(matches!(f.kind, FieldKind::Generic));
    }
}