use discrete_logarithm::discrete_log_with_order;
use rug::{ops::Pow, Integer};
use std::ops::{Add, Div, Mul, Neg, Sub};
/// A prime field `F_p`, represented by its modulus `p`.
///
/// The modulus is stored exactly as given; primality (and `p > 1`) is assumed
/// by the arithmetic built on top of it but is not checked here.
#[derive(Debug, Clone)]
pub struct PrimeField {
    modulus: Integer,
}
impl PrimeField {
    /// Creates the field with the given `modulus`.
    pub fn new(modulus: Integer) -> Self {
        Self { modulus }
    }
    /// Returns the field modulus `p`.
    pub fn modulus(&self) -> &Integer {
        &self.modulus
    }
    /// Returns the element of this field congruent to `value`, reduced into
    /// the canonical range `[0, p)`.
    ///
    /// rug's `%` rounds toward zero (the remainder keeps the dividend's sign),
    /// so a negative remainder is shifted up by `p` to keep the representative
    /// canonical; `is_zero`, `is_one` and equality rely on that.
    pub fn element(&self, value: Integer) -> FieldElement<'_> {
        let mut value = value % &self.modulus;
        if value < 0 {
            value += &self.modulus;
        }
        FieldElement { value, field: self }
    }
    /// Returns the additive identity `0`.
    pub fn zero(&self) -> FieldElement<'_> {
        self.element(Integer::from(0))
    }
    /// Returns the multiplicative identity `1` (reduced, so `0` when `p == 1`).
    pub fn one(&self) -> FieldElement<'_> {
        self.element(Integer::from(1))
    }
}
/// An element of a [`PrimeField`], stored together with a reference to its field.
///
/// Elements created through [`PrimeField::element`] hold a value reduced modulo
/// `p`; [`FieldElement::is_zero`] and [`FieldElement::is_one`] compare that
/// stored value directly.
#[derive(Debug, Clone)]
pub struct FieldElement<'a> {
    value: Integer,
    field: &'a PrimeField,
}
impl<'a> FieldElement<'a> {
    /// Returns the stored representative of this element.
    pub fn value(&self) -> &Integer {
        &self.value
    }
    /// Returns the field this element belongs to.
    ///
    /// The reference is valid for the field's own lifetime `'a` (like
    /// `QuadraticExtension::field`), not only while `self` is borrowed.
    pub fn field(&self) -> &'a PrimeField {
        self.field
    }
    /// Consumes the element and returns its stored value.
    pub fn into_value(self) -> Integer {
        self.value
    }
    /// Returns the multiplicative inverse, or `None` when the value has no
    /// inverse modulo `p` (for example zero).
    pub fn inv(&self) -> Option<FieldElement<'a>> {
        self.value
            .clone()
            .invert(self.field.modulus())
            .ok()
            .map(|inv| self.field.element(inv))
    }
    /// Raises this element to `exp` modulo `p`; a negative `exp` raises the
    /// inverse to `-exp`.
    ///
    /// # Panics
    /// Panics if `exp` is negative and this element is not invertible.
    pub fn pow(&self, exp: &Integer) -> FieldElement<'a> {
        let result = self
            .value
            .clone()
            .pow_mod(exp, self.field.modulus())
            .expect("negative exponent requires an invertible base");
        self.field.element(result)
    }
    /// Convenience wrapper around [`FieldElement::pow`] for small exponents.
    pub fn pow_u32(&self, exp: u32) -> FieldElement<'a> {
        self.pow(&Integer::from(exp))
    }
    /// Returns `true` if the stored value is `0`.
    pub fn is_zero(&self) -> bool {
        self.value == 0
    }
    /// Returns `true` if the stored value is `1`.
    pub fn is_one(&self) -> bool {
        self.value == 1
    }
}
/// Two elements are equal when they store the same value and belong to fields
/// with the same modulus. Without the modulus check, `3` in `F_7` and `3` in
/// `F_11` would compare equal.
impl<'a> PartialEq for FieldElement<'a> {
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
            // Cheap pointer check first; fall back to comparing moduli for
            // distinct-but-identical `PrimeField` instances.
            && (std::ptr::eq(self.field, other.field)
                || self.field.modulus() == other.field.modulus())
    }
}
impl<'a> Eq for FieldElement<'a> {}
impl<'a> Add for FieldElement<'a> {
    type Output = Self;
    /// Field addition: the integer sum reduced modulo `p` (operands consumed).
    fn add(self, rhs: Self) -> Self {
        let field = self.field;
        let sum = self.value + rhs.value;
        field.element(sum % field.modulus())
    }
}
impl<'a> Add for &FieldElement<'a> {
    type Output = FieldElement<'a>;
    /// Field addition on borrowed operands; returns a fresh element.
    fn add(self, rhs: Self) -> FieldElement<'a> {
        let sum = self.value.clone() + &rhs.value;
        self.field.element(sum % self.field.modulus())
    }
}
impl<'a> Sub for FieldElement<'a> {
    type Output = Self;
    /// Field subtraction; a negative intermediate difference is lifted by `p`.
    fn sub(self, rhs: Self) -> Self {
        let field = self.field;
        let modulus = field.modulus();
        let diff = (self.value - rhs.value) % modulus;
        let canonical = if diff < 0 { diff + modulus } else { diff };
        field.element(canonical)
    }
}
impl<'a> Sub for &FieldElement<'a> {
    type Output = FieldElement<'a>;
    /// Field subtraction on borrowed operands; negative differences wrap by `p`.
    fn sub(self, rhs: Self) -> FieldElement<'a> {
        let modulus = self.field.modulus();
        let remainder = (self.value.clone() - &rhs.value) % modulus;
        let wrapped = match remainder < 0 {
            true => remainder + modulus,
            false => remainder,
        };
        self.field.element(wrapped)
    }
}
impl<'a> Mul for FieldElement<'a> {
    type Output = Self;
    /// Field multiplication: the integer product reduced modulo `p`.
    fn mul(self, rhs: Self) -> Self {
        let field = self.field;
        let product = self.value * rhs.value;
        field.element(product % field.modulus())
    }
}
impl<'a> Mul for &FieldElement<'a> {
    type Output = FieldElement<'a>;
    /// Field multiplication on borrowed operands; returns a fresh element.
    fn mul(self, rhs: Self) -> FieldElement<'a> {
        let product = self.value.clone() * &rhs.value;
        self.field.element(product % self.field.modulus())
    }
}
impl<'a> Div for FieldElement<'a> {
    type Output = Self;
    // Division is `self * rhs^-1`; clippy flags a `*` inside `div` as
    // suspicious, which is exactly what is intended here.
    #[allow(clippy::suspicious_arithmetic_impl)]
    /// Field division.
    ///
    /// # Panics
    /// Panics if `rhs` has no multiplicative inverse (e.g. it is zero).
    fn div(self, rhs: Self) -> Self {
        match rhs.inv() {
            Some(reciprocal) => self * reciprocal,
            None => panic!("Division by zero or non-invertible element"),
        }
    }
}
impl<'a> Div for &FieldElement<'a> {
    type Output = FieldElement<'a>;
    // Implemented as multiplication by the inverse; the `*` inside `div` is
    // deliberate, so clippy's suspicious-arithmetic lint is silenced.
    #[allow(clippy::suspicious_arithmetic_impl)]
    /// Field division on borrowed operands.
    ///
    /// # Panics
    /// Panics if `rhs` has no multiplicative inverse (e.g. it is zero).
    fn div(self, rhs: Self) -> FieldElement<'a> {
        let reciprocal = rhs.inv().expect("Division by zero or non-invertible element");
        self * &reciprocal
    }
}
impl<'a> Neg for FieldElement<'a> {
    type Output = Self;
    /// Additive inverse `p - value`, reduced back into the field so that
    /// negating zero yields zero.
    fn neg(self) -> Self {
        let field = self.field;
        field.element(field.modulus() - self.value)
    }
}
impl<'a> Neg for &FieldElement<'a> {
    type Output = FieldElement<'a>;
    /// Additive inverse of a borrowed element (`p - value`, reduced mod `p`).
    fn neg(self) -> FieldElement<'a> {
        let value = self.value.clone();
        self.field.element(self.field.modulus() - value)
    }
}
/// An element `real + imag·√ω` of the quadratic extension `F_p[√ω]`.
///
/// Multiplication uses `(√ω)^2 = ω`, taking `ω` from the left operand; both
/// operands are assumed to share the same base field and `ω`.
/// NOTE(review): `ω` is presumably meant to be a quadratic non-residue mod `p`
/// (so that this is a field); that is not checked here.
#[derive(Debug, Clone)]
pub struct QuadraticExtension<'a> {
    /// Coefficient of `1`, kept in `[0, p)`.
    real: Integer,
    /// Coefficient of `√ω`, kept in `[0, p)`.
    imag: Integer,
    /// The base field `F_p`.
    field: &'a PrimeField,
    /// The square of the adjoined element; stored exactly as given.
    omega: Integer,
}
impl<'a> QuadraticExtension<'a> {
    /// Creates `real + imag·√omega`, reducing both coefficients into `[0, p)`.
    pub fn new(field: &'a PrimeField, real: Integer, imag: Integer, omega: Integer) -> Self {
        let real = Self::reduce(real, field.modulus());
        let imag = Self::reduce(imag, field.modulus());
        Self {
            real,
            imag,
            field,
            omega,
        }
    }
    /// Reduces `value` into `[0, modulus)`. rug's `%` keeps the dividend's
    /// sign, so negative remainders are shifted up by `modulus`.
    fn reduce(value: Integer, modulus: &Integer) -> Integer {
        let mut value = value % modulus;
        if value < 0 {
            value += modulus;
        }
        value
    }
    // Accessor kept for API completeness; may be unused in non-test builds.
    #[allow(dead_code)]
    /// Returns the base field.
    pub fn field(&self) -> &'a PrimeField {
        self.field
    }
    /// Returns the coefficient of `1`.
    pub fn real(&self) -> &Integer {
        &self.real
    }
    // Accessor kept for API completeness; may be unused in non-test builds.
    #[allow(dead_code)]
    /// Returns the coefficient of `√ω`.
    pub fn imag(&self) -> &Integer {
        &self.imag
    }
    // Accessor kept for API completeness; may be unused in non-test builds.
    #[allow(dead_code)]
    /// Returns `ω`.
    pub fn omega(&self) -> &Integer {
        &self.omega
    }
    /// Multiplies `(a + b√ω)(c + d√ω) = (ac + bdω) + (ad + bc)√ω`.
    pub fn mul(&self, other: &Self) -> Self {
        let ac = (self.real.clone() * &other.real) % self.field.modulus();
        let bd = (self.imag.clone() * &other.imag) % self.field.modulus();
        let bd_omega = (bd * &self.omega) % self.field.modulus();
        let real_part = (ac + bd_omega) % self.field.modulus();
        let ad = (self.real.clone() * &other.imag) % self.field.modulus();
        let bc = (self.imag.clone() * &other.real) % self.field.modulus();
        let imag_part = (ad + bc) % self.field.modulus();
        Self::new(self.field, real_part, imag_part, self.omega.clone())
    }
    /// Squares `(a + b√ω)^2 = (a^2 + b^2ω) + 2ab√ω`.
    pub fn square(&self) -> Self {
        let a2 = (self.real.clone() * &self.real) % self.field.modulus();
        let b2 = (self.imag.clone() * &self.imag) % self.field.modulus();
        let b2_omega = (b2 * &self.omega) % self.field.modulus();
        let real_part = (a2 + b2_omega) % self.field.modulus();
        let two = Integer::from(2);
        let temp = (two * &self.real) % self.field.modulus();
        let imag_part = (temp * &self.imag) % self.field.modulus();
        Self::new(self.field, real_part, imag_part, self.omega.clone())
    }
    /// Raises this element to a non-negative power by square-and-multiply.
    /// `exp == 0` yields the identity `1 + 0·√ω`.
    ///
    /// # Panics
    /// Panics if `exp` is negative (previously a negative exponent silently
    /// returned the identity).
    pub fn pow(&self, exp: &Integer) -> Self {
        assert!(
            *exp >= 0,
            "QuadraticExtension::pow requires a non-negative exponent"
        );
        let mut result = Self::new(
            self.field,
            Integer::from(1),
            Integer::from(0),
            self.omega.clone(),
        );
        let mut base = self.clone();
        let mut e = exp.clone();
        while e > 0 {
            if e.is_odd() {
                result = result.mul(&base);
            }
            base = base.square();
            e >>= 1;
        }
        result
    }
}
pub fn rth_roots(field: &PrimeField, delta: &Integer, r: u32) -> Vec<Integer> {
if r == 0 || r > 10000 || delta == &Integer::from(0) {
return Vec::new();
}
let p = field.modulus();
let pm1: Integer = p.clone() - 1;
let r_int = Integer::from(r);
if pm1.clone() % &r_int != 0 {
return Vec::new();
}
let mut t = 0u32;
let mut s = pm1.clone();
while s.clone() % &r_int == 0 {
t += 1;
s /= &r_int;
}
if t == 0 {
return Vec::new();
}
let delta_elem = field.element(delta.clone());
let exp_omega = pm1.clone() / &r_int;
let omega = {
let mut omega = field.one();
for candidate in 2..1000 {
let g = field.element(Integer::from(candidate));
omega = g.pow(&exp_omega);
if !omega.is_one() {
break;
}
}
omega
};
if t == 1 {
let inv_r = match r_int.clone().invert(&s) {
Ok(inv) => inv,
Err(_) => return Vec::new(),
};
let root = delta_elem.pow(&inv_r);
let mut roots = Vec::with_capacity(r as usize);
let mut current = root;
for _ in 0..r {
roots.push(current.value().clone());
current = ¤t * ω
}
return roots;
}
let p_gen = {
let exp_test = pm1.clone() / &r_int;
let mut p_gen = field.one();
for candidate in 2..1000 {
let g = field.element(Integer::from(candidate));
let test = g.pow(&exp_test);
if !test.is_one() {
p_gen = g;
break;
}
}
p_gen
};
let mut k = Integer::from(1);
while (k.clone() * &s + Integer::from(1)) % &r_int != 0 {
k += 1;
}
let alpha = (k * &s + Integer::from(1)) / &r_int;
let r_power_reduced = r_int.clone().pow_mod(&Integer::from(t - 1), &pm1).unwrap();
let exp_a = r_power_reduced.clone() * &s % &pm1;
let a = p_gen.pow(&exp_a);
let mut b = delta_elem.pow(&(r_int.clone() * &alpha - 1));
let mut c = p_gen.pow(&s);
let mut h = field.one();
if a.is_one() {
return Vec::new();
}
for i in 1..t {
let exp_d = r_int.clone().pow(t - 1 - i) % &pm1;
let d = b.pow(&exp_d);
let j = if d.is_one() {
Integer::from(0)
} else {
let d_inv = field.element(d.value().clone().invert(p).unwrap());
match discrete_log_with_order(d_inv.value(), a.value(), p, &r_int) {
Ok(j_val) => j_val,
Err(_) => {
if r <= 10000 {
let mut found = None;
let mut a_power = field.one();
for k in 0..r {
if a_power.value() == d_inv.value() {
found = Some(Integer::from(k));
break;
}
a_power = &a_power * &a;
}
found.unwrap_or_else(|| Integer::from(0))
} else {
Integer::from(0)
}
}
}
};
let c_r = c.pow(&r_int);
b = &b * &c_r.pow(&j);
h = &h * &c.pow(&j);
c = c_r;
}
let root = &delta_elem.pow(&alpha) * &h;
let mut roots = Vec::with_capacity(r as usize);
let mut current = root;
for _ in 0..r {
roots.push(current.value().clone());
current = ¤t * ω
}
roots
}
#[cfg(test)]
mod tests {
use super::*;
use std::{collections::HashSet, str::FromStr};
#[test]
fn test_basic_arithmetic() {
let p = Integer::from(17);
let field = PrimeField::new(p);
let a = field.element(Integer::from(5));
let b = field.element(Integer::from(12));
let sum = &a + &b;
assert_eq!(sum.value(), &Integer::from(0));
let diff = &a - &b;
assert_eq!(diff.value(), &Integer::from(10));
let prod = &a * &b;
assert_eq!(prod.value(), &Integer::from(9));
}
#[test]
fn test_inversion() {
let p = Integer::from(17);
let field = PrimeField::new(p);
let a = field.element(Integer::from(5));
let a_inv = a.inv().unwrap();
assert_eq!(a_inv.value(), &Integer::from(7));
let prod = &a * &a_inv;
assert!(prod.is_one());
}
#[test]
fn test_power() {
let p = Integer::from(17);
let field = PrimeField::new(p);
let a = field.element(Integer::from(3));
let result = a.pow(&Integer::from(4));
assert_eq!(result.value(), &Integer::from(13));
}
#[test]
fn test_division() {
let p = Integer::from(17);
let field = PrimeField::new(p);
let a = field.element(Integer::from(10));
let b = field.element(Integer::from(5));
let quot = &a / &b;
assert_eq!(quot.value(), &Integer::from(2));
}
#[test]
fn test_large_field() {
let p = Integer::from_str("340282366920938463463374607431768211297").unwrap();
let field = PrimeField::new(p);
let a = field.element(Integer::from_str("123456789012345678901234567890").unwrap());
let b = field.element(Integer::from_str("987654321098765432109876543210").unwrap());
let sum = &a + &b;
let prod = &a * &b;
assert!(sum.value() < field.modulus());
assert!(prod.value() < field.modulus());
}
#[test]
fn test_quadratic_extension_basic() {
let p = Integer::from(7);
let field = PrimeField::new(p);
let omega = Integer::from(2);
let a = QuadraticExtension::new(&field, Integer::from(3), Integer::from(4), omega.clone());
let result = a.square();
assert_eq!(result.real(), &Integer::from(6));
assert_eq!(result.imag(), &Integer::from(3));
}
#[test]
fn test_quadratic_extension_mul() {
let p = Integer::from(7);
let field = PrimeField::new(p);
let omega = Integer::from(2);
let a = QuadraticExtension::new(&field, Integer::from(1), Integer::from(2), omega.clone());
let b = QuadraticExtension::new(&field, Integer::from(3), Integer::from(4), omega.clone());
let result = a.mul(&b);
assert_eq!(result.real(), &Integer::from(5));
assert_eq!(result.imag(), &Integer::from(3));
}
#[test]
fn test_quadratic_extension_pow() {
let p = Integer::from(7);
let field = PrimeField::new(p);
let omega = Integer::from(2);
let a = QuadraticExtension::new(&field, Integer::from(2), Integer::from(1), omega.clone());
let result = a.pow(&Integer::from(0));
assert_eq!(result.real(), &Integer::from(1));
assert_eq!(result.imag(), &Integer::from(0));
let result = a.pow(&Integer::from(1));
assert_eq!(result.real(), &Integer::from(2));
assert_eq!(result.imag(), &Integer::from(1));
}
#[test]
fn test_rth_roots() {
let q = Integer::from_str("9908484735485245740582755998843475068910570989512225739800304203500256711207262150930812622460031920899674919818007279858208368349928684334780223996774347").unwrap();
let c = Integer::from_str("7267288183214469410349447052665186833632058119533973432573869246434984462336560480880459677870106195135869371300420762693116774837763418518542884912967719").unwrap();
let e = 21;
let field = PrimeField::new(q.clone());
let roots = rth_roots(&field, &c, e);
let unique_roots: HashSet<_> = roots.iter().collect();
assert_eq!(unique_roots.len(), 7);
for root in &roots {
let check = root.clone().pow_mod(&Integer::from(e), &q).unwrap();
assert_eq!(check, c, "root^21 mod q should equal c");
}
}
}