use crate::field::element::FieldElement;
use crate::field::errors::FieldError;
use crate::field::traits::IsField;
#[cfg(feature = "lambdaworks-serde-binary")]
use crate::traits::ByteConversion;
use core::fmt::Debug;
use core::marker::PhantomData;
/// Marker type for a degree-two extension of the base field described by `T`.
///
/// Elements (see [`QuadraticExtensionFieldElement`]) are pairs `[a0, a1]` standing for
/// `a0 + a1 * u`, where `u^2` equals the constant returned by
/// [`HasQuadraticNonResidue::residue`]. The struct holds no data; the `PhantomData`
/// field only ties the type to `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QuadraticExtensionField<T> {
// Zero-sized marker; never read.
phantom: PhantomData<T>,
}
/// An element of the quadratic extension defined by `T`.
pub type QuadraticExtensionFieldElement<T> = FieldElement<QuadraticExtensionField<T>>;
/// Supplies the constant `q` defining the quadratic extension `BaseField[u] / (u^2 - q)`.
///
/// Implementors are expected to return a quadratic non-residue of `BaseField`. The
/// inversion in the `IsField` implementation for [`QuadraticExtensionField`] divides by the
/// norm `a0^2 - q * a1^2`. That norm only vanishes for the zero element when `q` is a
/// non-residue.
pub trait HasQuadraticNonResidue {
/// The field being extended.
type BaseField: IsField;
/// Returns `q`. Despite the method name, this must be a quadratic *non*-residue.
fn residue() -> FieldElement<Self::BaseField>;
}
impl<Q> FieldElement<QuadraticExtensionField<Q>>
where
    Q: Clone + Debug + HasQuadraticNonResidue,
{
    /// Returns the conjugate `a0 - a1 * u` of `self = a0 + a1 * u`.
    ///
    /// This is the image of `self` under the map `u -> -u`. Multiplying an element by its
    /// conjugate gives `a0^2 - q * a1^2`, which has no `u` component.
    pub fn conjugate(&self) -> Self {
        let coeffs = self.value();
        let c0 = coeffs[0].clone();
        let c1 = -&coeffs[1];
        Self::new([c0, c1])
    }
}
/// Placeholder byte conversion for a pair of field elements, the representation used by
/// [`QuadraticExtensionField`] elements.
///
/// No conversion is implemented yet: every method panics through `unimplemented!()`.
#[cfg(feature = "lambdaworks-serde-binary")]
impl<F> ByteConversion for [FieldElement<F>; 2]
where
    F: IsField,
{
    /// Not implemented; always panics.
    #[cfg(feature = "std")]
    fn to_bytes_le(&self) -> Vec<u8> {
        unimplemented!()
    }

    /// Not implemented; always panics.
    #[cfg(feature = "std")]
    fn to_bytes_be(&self) -> Vec<u8> {
        unimplemented!()
    }

    /// Not implemented; always panics instead of returning a `Result`.
    fn from_bytes_le(_input: &[u8]) -> Result<Self, crate::errors::ByteConversionError>
    where
        Self: Sized,
    {
        unimplemented!()
    }

    /// Not implemented; always panics instead of returning a `Result`.
    fn from_bytes_be(_input: &[u8]) -> Result<Self, crate::errors::ByteConversionError>
    where
        Self: Sized,
    {
        unimplemented!()
    }
}
impl<Q> IsField for QuadraticExtensionField<Q>
where
Q: Clone + Debug + HasQuadraticNonResidue,
{
type BaseType = [FieldElement<Q::BaseField>; 2];
fn add(
a: &[FieldElement<Q::BaseField>; 2],
b: &[FieldElement<Q::BaseField>; 2],
) -> [FieldElement<Q::BaseField>; 2] {
[&a[0] + &b[0], &a[1] + &b[1]]
}
fn mul(
a: &[FieldElement<Q::BaseField>; 2],
b: &[FieldElement<Q::BaseField>; 2],
) -> [FieldElement<Q::BaseField>; 2] {
let q = Q::residue();
let a0b0 = &a[0] * &b[0];
let a1b1 = &a[1] * &b[1];
let z = (&a[0] + &a[1]) * (&b[0] + &b[1]);
[&a0b0 + &a1b1 * q, z - a0b0 - a1b1]
}
fn square(a: &[FieldElement<Q::BaseField>; 2]) -> [FieldElement<Q::BaseField>; 2] {
let [a0, a1] = a;
let v0 = a0 * a1;
let c0 = (a0 + a1) * (a0 + Q::residue() * a1) - &v0 - Q::residue() * &v0;
let c1 = &v0 + &v0;
[c0, c1]
}
fn sub(
a: &[FieldElement<Q::BaseField>; 2],
b: &[FieldElement<Q::BaseField>; 2],
) -> [FieldElement<Q::BaseField>; 2] {
[&a[0] - &b[0], &a[1] - &b[1]]
}
fn neg(a: &[FieldElement<Q::BaseField>; 2]) -> [FieldElement<Q::BaseField>; 2] {
[-&a[0], -&a[1]]
}
fn inv(
a: &[FieldElement<Q::BaseField>; 2],
) -> Result<[FieldElement<Q::BaseField>; 2], FieldError> {
let inv_norm = (a[0].pow(2_u64) - Q::residue() * a[1].pow(2_u64)).inv()?;
Ok([&a[0] * &inv_norm, -&a[1] * inv_norm])
}
fn div(
a: &[FieldElement<Q::BaseField>; 2],
b: &[FieldElement<Q::BaseField>; 2],
) -> [FieldElement<Q::BaseField>; 2] {
Self::mul(a, &Self::inv(b).unwrap())
}
fn eq(a: &[FieldElement<Q::BaseField>; 2], b: &[FieldElement<Q::BaseField>; 2]) -> bool {
a[0] == b[0] && a[1] == b[1]
}
fn zero() -> [FieldElement<Q::BaseField>; 2] {
[FieldElement::zero(), FieldElement::zero()]
}
fn one() -> [FieldElement<Q::BaseField>; 2] {
[FieldElement::one(), FieldElement::zero()]
}
fn from_u64(x: u64) -> Self::BaseType {
[FieldElement::from(x), FieldElement::zero()]
}
fn from_base_type(x: [FieldElement<Q::BaseField>; 2]) -> [FieldElement<Q::BaseField>; 2] {
x
}
}
impl<Q: Clone + Debug + HasQuadraticNonResidue> FieldElement<QuadraticExtensionField<Q>> {}
#[cfg(test)]
mod tests {
    use crate::field::fields::u64_prime_field::{U64FieldElement, U64PrimeField};
    const ORDER_P: u64 = 59;
    use super::*;

    /// `q = -1`, which is a non-residue mod 59 because 59 ≡ 3 (mod 4).
    #[derive(Debug, Clone)]
    struct MyQuadraticNonResidue;
    impl HasQuadraticNonResidue for MyQuadraticNonResidue {
        type BaseField = U64PrimeField<ORDER_P>;
        fn residue() -> FieldElement<U64PrimeField<ORDER_P>> {
            -FieldElement::one()
        }
    }

    type FE = U64FieldElement<ORDER_P>;
    type MyFieldExtensionBackend = QuadraticExtensionField<MyQuadraticNonResidue>;
    #[allow(clippy::upper_case_acronyms)]
    type FEE = FieldElement<MyFieldExtensionBackend>;

    #[test]
    fn test_add_1() {
        let a = FEE::new([FE::new(0), FE::new(3)]);
        let b = FEE::new([-FE::new(2), FE::new(8)]);
        let expected_result = FEE::new([FE::new(57), FE::new(11)]);
        assert_eq!(a + b, expected_result);
    }

    #[test]
    fn test_add_2() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let b = FEE::new([-FE::new(4), FE::new(2)]);
        let expected_result = FEE::new([FE::new(8), FE::new(7)]);
        assert_eq!(a + b, expected_result);
    }

    #[test]
    fn test_sub_1() {
        let a = FEE::new([FE::new(0), FE::new(3)]);
        let b = FEE::new([-FE::new(2), FE::new(8)]);
        let expected_result = FEE::new([FE::new(2), FE::new(54)]);
        assert_eq!(a - b, expected_result);
    }

    #[test]
    fn test_sub_2() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let b = FEE::new([-FE::new(4), FE::new(2)]);
        let expected_result = FEE::new([FE::new(16), FE::new(3)]);
        assert_eq!(a - b, expected_result);
    }

    #[test]
    fn test_mul_1() {
        let a = FEE::new([FE::new(0), FE::new(3)]);
        let b = FEE::new([-FE::new(2), FE::new(8)]);
        let expected_result = FEE::new([FE::new(35), FE::new(53)]);
        assert_eq!(a * b, expected_result);
    }

    #[test]
    fn test_mul_2() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let b = FEE::new([-FE::new(4), FE::new(2)]);
        let expected_result = FEE::new([FE::new(1), FE::new(4)]);
        assert_eq!(a * b, expected_result);
    }

    #[test]
    fn test_square() {
        // (12 + 5u)^2 = (144 - 25) + 120u = 119 + 120u ≡ 1 + 2u (mod 59).
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let expected_result = FEE::new([FE::new(1), FE::new(2)]);
        assert_eq!(a.square(), expected_result);
        assert_eq!(a.square(), &a * &a);
    }

    #[test]
    fn test_div_1() {
        let a = FEE::new([FE::new(0), FE::new(3)]);
        let b = FEE::new([-FE::new(2), FE::new(8)]);
        let expected_result = FEE::new([FE::new(42), FE::new(19)]);
        assert_eq!(a / b, expected_result);
    }

    #[test]
    fn test_div_2() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let b = FEE::new([-FE::new(4), FE::new(2)]);
        let expected_result = FEE::new([FE::new(4), FE::new(45)]);
        assert_eq!(a / b, expected_result);
    }

    #[test]
    #[should_panic]
    fn test_div_by_zero_panics() {
        let _ = FEE::one() / FEE::zero();
    }

    #[test]
    fn test_pow_1() {
        let a = FEE::new([FE::new(0), FE::new(3)]);
        let b: u64 = 5;
        let expected_result = FEE::new([FE::new(0), FE::new(7)]);
        assert_eq!(a.pow(b), expected_result);
    }

    #[test]
    fn test_pow_2() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let b: u64 = 8;
        let expected_result = FEE::new([FE::new(52), FE::new(35)]);
        assert_eq!(a.pow(b), expected_result);
    }

    #[test]
    fn test_inv_1() {
        let a = FEE::new([FE::new(0), FE::new(3)]);
        let expected_result = FEE::new([FE::new(0), FE::new(39)]);
        assert_eq!(a.inv().unwrap(), expected_result);
    }

    #[test]
    fn test_inv_2() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let expected_result = FEE::new([FE::new(28), FE::new(8)]);
        assert_eq!(a.inv().unwrap(), expected_result);
    }

    #[test]
    fn test_mul_by_inverse_is_one() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        assert_eq!(&a * &a.inv().unwrap(), FEE::one());
    }

    #[test]
    fn test_inv_of_zero_is_error() {
        assert!(FEE::zero().inv().is_err());
    }

    #[test]
    fn test_conjugate() {
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let expected_result = FEE::new([FE::new(12), -FE::new(5)]);
        assert_eq!(a.conjugate(), expected_result);
    }

    #[test]
    fn test_mul_by_conjugate_is_norm() {
        // (12 + 5u)(12 - 5u) = 144 - q * 25 = 144 + 25 = 169 ≡ 51 (mod 59), with no u component.
        let a = FEE::new([FE::new(12), FE::new(5)]);
        let expected_result = FEE::new([FE::new(51), FE::zero()]);
        assert_eq!(&a * &a.conjugate(), expected_result);
    }
}