use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
use zeroize::Zeroize;
use super::FieldImplementation;
/// Public limb layout of a field element: eight 32-bit words.
pub type Limbs = [u32; 8];
// Private alias for the same eight-word array. It is the pointee type of the
// external `fe25519_*_asm` routines and the type of the scratch buffers that
// the operator impls in this file hand to them.
type U256 = [u32; 8];
// Field-arithmetic routines provided by an external object with C linkage.
// The `_asm` suffix suggests hand-written assembly — TODO confirm against the
// build setup. Every caller in this file passes pointers derived from live
// references to `[u32; 8]` arrays, with `result` pointing at a fresh local.
// NOTE(review): whether `result` may alias an operand, and whether outputs are
// fully reduced, is not visible from these declarations.
extern "C" {
/// Used by the `Add` impl: expected to store `left + right` through `result`.
pub fn fe25519_add_asm(result: *mut U256, left: *const U256, right: *const U256);
/// Used by the `Mul` impl: expected to store `left * right` through `result`.
pub fn fe25519_mul_asm(result: *mut U256, left: *const U256, right: *const U256);
/// Used by `squared`: expected to store `value * value` through `result`.
pub fn fe25519_square_asm(result: *mut U256, value: *const U256);
}
/// Field element held as eight 32-bit limbs, least-significant limb first
/// (`ONE` is `[1, 0, 0, 0, 0, 0, 0, 0]`).
///
/// The stored limbs are not necessarily the canonical representative:
/// `to_bytes` runs `reduce_completely` on a copy before encoding, and equality
/// is decided on those canonical bytes rather than on the raw limbs.
#[derive(Clone, Copy, Debug, Default, Zeroize)]
pub struct FieldElement(pub Limbs);
impl ConditionallySelectable for FieldElement {
    /// Builds a new element limb by limb, delegating each limb to
    /// `u32::conditional_select` with the same `choice`.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut picked = Self::default();
        for ((out, lhs), rhs) in picked.0.iter_mut().zip(a.0.iter()).zip(b.0.iter()) {
            *out = u32::conditional_select(lhs, rhs, choice);
        }
        picked
    }

    /// Applies `u32::conditional_swap` to every pair of corresponding limbs
    /// of `a` and `b`, using the same `choice` for all eight.
    fn conditional_swap(a: &mut Self, b: &mut Self, choice: Choice) {
        for limb in 0..8 {
            u32::conditional_swap(&mut a.0[limb], &mut b.0[limb], choice);
        }
    }
}
impl FieldImplementation for FieldElement {
    type Limbs = Limbs;

    /// Additive identity.
    const ZERO: Self = Self([0, 0, 0, 0, 0, 0, 0, 0]);

    /// Multiplicative identity.
    const ONE: Self = Self([1, 0, 0, 0, 0, 0, 0, 0]);

    /// Edwards curve constant `d`, least-significant limb first.
    const D: Self = Self([
        0x135978a3, 0x75eb4dca, 0x4141d8ab, 0x00700a4d, 0x7779e898, 0x8cc74079, 0x2b6ffe73,
        0x52036cee,
    ]);

    /// `2 * d`, stored already reduced.
    const D2: Self = Self([
        0x26b2f159, 0xebd69b94, 0x8283b156, 0x00e0149a, 0xeef3d130, 0x198e80f2, 0x56dffce7,
        0x2406d9dc,
    ]);

    /// x-coordinate of the Edwards base point.
    const EDWARDS_BASEPOINT_X: Self = Self([
        0x8f25d51a, 0xc9562d60, 0x9525a7b2, 0x692cc760, 0xfdd6dc5c, 0xc0a4e231, 0xcd6e53fe,
        0x216936d3,
    ]);

    /// y-coordinate of the Edwards base point.
    const EDWARDS_BASEPOINT_Y: Self = Self([
        0x6666_6658,
        0x6666_6666,
        0x6666_6666,
        0x6666_6666,
        0x6666_6666,
        0x6666_6666,
        0x6666_6666,
        0x6666_6666,
    ]);

    /// A square root of -1 in the field.
    const I: Self = Self([
        0x4a0ea0b0, 0xc4ee1b27, 0xad2fe478, 0x2f431806, 0x3dfbd7a7, 0x2b4d0099, 0x4fc1df0b,
        0x2b832480,
    ]);

    /// The Montgomery ladder constant `(A + 2) / 4 = 121666`.
    const APLUS2_OVER_FOUR: Self = Self([121666, 0, 0, 0, 0, 0, 0, 0]);

    /// u-coordinate of the Montgomery base point.
    const MONTGOMERY_BASEPOINT_U: Self = Self([9, 0, 0, 0, 0, 0, 0, 0]);

    /// Returns the canonical 32-byte little-endian encoding of this element.
    ///
    /// A copy is fully reduced with `reduce_completely` first, so every
    /// representation of the same residue yields the same bytes. Limbs are
    /// serialized explicitly as little-endian words instead of being
    /// reinterpreted in memory, which keeps the encoding independent of the
    /// host byte order and needs no `unsafe`.
    fn to_bytes(&self) -> [u8; 32] {
        let mut fe = *self;
        FieldElement::reduce_completely(&mut fe);

        let mut bytes = [0u8; 32];
        for (chunk, limb) in bytes.chunks_exact_mut(4).zip(fe.0.iter()) {
            chunk.copy_from_slice(&limb.to_le_bytes());
        }
        bytes
    }

    /// Interprets `bytes` as a little-endian integer and clears bit 255.
    ///
    /// No range check is performed: the result may be a non-canonical
    /// representation (a value in `[p, 2^255)`), hence "unchecked".
    fn from_bytes_unchecked(bytes: &[u8; 32]) -> FieldElement {
        let mut limbs = U256::default();
        for (limb, chunk) in limbs.iter_mut().zip(bytes.chunks_exact(4)) {
            let mut word = [0u8; 4];
            word.copy_from_slice(chunk);
            *limb = u32::from_le_bytes(word);
        }
        // Drop the most significant bit of the 256-bit input.
        limbs[7] &= 0x7fff_ffff;
        FieldElement(limbs)
    }

    /// Raises `self` to the power `2^255 - 21` (that is `p - 2`) by
    /// square-and-multiply, which is the multiplicative inverse by Fermat's
    /// little theorem.
    ///
    /// The accumulator starts as `self` (exponent bit 254); every lower bit of
    /// the exponent is 1 except bits 2 and 4, where the multiply is skipped.
    fn inverse(&self) -> FieldElement {
        let mut inverse = *self;
        for i in (0..=253).rev() {
            inverse = inverse.squared();
            if i != 2 && i != 4 {
                inverse = &inverse * self;
            }
        }
        inverse
    }

    /// Returns `self * self`, computed by the external squaring routine.
    fn squared(&self) -> FieldElement {
        let mut square = U256::default();
        // SAFETY: both pointers come from live references to `[u32; 8]`
        // arrays and `square` is a distinct local buffer. Assumes the routine
        // touches exactly eight words per pointer — TODO confirm.
        unsafe {
            fe25519_square_asm(&mut square, &self.0);
        }
        FieldElement(square)
    }

    /// Raises `self` to the power `2^252 - 3` (that is `(p - 5) / 8`) by
    /// square-and-multiply.
    ///
    /// The accumulator starts as `self` (exponent bit 251); every lower bit of
    /// the exponent is 1 except bit 1, where the multiply is skipped.
    fn pow2523(&self) -> FieldElement {
        let mut sqrt = *self;
        for i in (0..=250).rev() {
            sqrt = sqrt.squared();
            if i != 1 {
                sqrt = &sqrt * self;
            }
        }
        sqrt
    }
}
impl ConstantTimeEq for FieldElement {
    /// Compares the canonical byte encodings of both operands, so two
    /// different limb representations of the same residue compare equal.
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs = self.to_bytes();
        let rhs = other.to_bytes();
        lhs[..].ct_eq(&rhs[..])
    }
}
impl PartialEq for FieldElement {
fn eq(&self, other: &Self) -> bool {
bool::from(self.ct_eq(other))
}
}
impl<'a, 'b> Add<&'b FieldElement> for &'a FieldElement {
type Output = FieldElement;
fn add(self, other: &'b FieldElement) -> FieldElement {
let mut sum = U256::default();
unsafe {
fe25519_add_asm(&mut sum, &self.0, &other.0);
}
FieldElement(sum)
}
}
impl<'b> AddAssign<&'b FieldElement> for FieldElement {
    /// In-place addition, expressed through the by-reference `Add` impl.
    fn add_assign(&mut self, other: &'b FieldElement) {
        let sum = &*self + other;
        *self = sum;
    }
}
impl<'a> Neg for &'a FieldElement {
    type Output = FieldElement;

    /// Negation is computed as `0 - self` via the by-reference `Sub` impl.
    fn neg(self) -> FieldElement {
        let zero = FieldElement::ZERO;
        &zero - self
    }
}
impl<'a, 'b> Sub<&'b FieldElement> for &'a FieldElement {
    type Output = FieldElement;

    /// Limb-wise subtraction with a signed 64-bit borrow chain.
    ///
    /// The result is a valid but generally non-canonical representation: the
    /// top limb is handled first so that everything at or above bit 255 can be
    /// folded back into the lowest limb as a multiple of 19.
    fn sub(self, other: &'b FieldElement) -> FieldElement {
        let mut limbs = U256::default();

        // Signed difference of the most significant limbs.
        let top = i64::from(self.0[7]) - i64::from(other.0[7]);

        // Keep the low 32 bits and force bit 31 on, so the final borrow
        // propagation into this limb cannot underflow it.
        limbs[7] = (top as u32) | 0x8000_0000;

        // `top >> 31` (arithmetic shift) is the part at or above bit 255; the
        // extra `- 1` cancels the bit forced on above. Each unit there is
        // worth 19 in the lowest limb.
        let mut carry = i64::from((((top >> 31) as i32) - 1) * 19);

        for i in 0..7 {
            carry += i64::from(self.0[i]);
            carry -= i64::from(other.0[i]);
            limbs[i] = carry as u32;
            // Arithmetic shift keeps a negative borrow negative.
            carry >>= 32;
        }

        carry += i64::from(limbs[7]);
        limbs[7] = carry as u32;

        FieldElement(limbs)
    }
}
impl<'b> SubAssign<&'b FieldElement> for FieldElement {
    /// In-place subtraction, expressed through the by-reference `Sub` impl.
    fn sub_assign(&mut self, other: &'b FieldElement) {
        let difference = &*self - other;
        *self = difference;
    }
}
impl<'a, 'b> Mul<&'b FieldElement> for &'a FieldElement {
type Output = FieldElement;
fn mul(self, other: &'b FieldElement) -> FieldElement {
let mut product = U256::default();
unsafe {
fe25519_mul_asm(&mut product, &self.0, &other.0);
}
FieldElement(product)
}
}
impl<'b> MulAssign<&'b FieldElement> for FieldElement {
    /// In-place multiplication, expressed through the by-reference `Mul` impl.
    fn mul_assign(&mut self, other: &'b FieldElement) {
        let product = &*self * other;
        *self = product;
    }
}
impl FieldElement {
    /// Reduces `value` in place to its canonical representative in
    /// `[0, 2^255 - 19)`.
    ///
    /// Works in two passes over the limbs. The first pass only tracks carries
    /// to decide how many times the modulus must be removed; the second pass
    /// adds that count times 19 and clears bit 255, which subtracts the same
    /// multiple of `2^255 - 19`.
    pub fn reduce_completely(value: &mut FieldElement) {
        // Initial estimate from bit 255 of the input.
        let top_bit = u64::from(value.0[7] >> 31);

        // Pass 1: propagate `19 * (estimate + 1)` through the low limbs,
        // keeping only the carry, to see whether the sum crosses 2^255.
        let mut carry = top_bit * 19 + 19;
        for limb in value.0.iter().take(7) {
            carry = (carry + u64::from(*limb)) >> 32;
        }
        carry += u64::from(value.0[7]);
        let multiples = (carry >> 31) as u32;

        // Pass 2: add `19 * multiples` for real, then drop bit 255 and above.
        let mut carry = u64::from(multiples) * 19;
        for limb in value.0.iter_mut().take(7) {
            carry += u64::from(*limb);
            *limb = carry as u32;
            carry >>= 32;
        }
        carry += u64::from(value.0[7]);
        value.0[7] = (carry as u32) & 0x7fff_ffff;
    }
}
#[cfg(test)]
mod tests {
    use super::FieldElement;
    use crate::field::FieldImplementation;
    use subtle::ConstantTimeEq;

    /// 1 + 1 equals 2, both limb-for-limb and under constant-time equality.
    #[test]
    fn test_one_plus_one() {
        let one = FieldElement::ONE;
        let two = &one + &one;

        // `FieldElement` wraps `[u32; 8]`, so the literal must have exactly
        // eight limbs; a longer array does not type-check.
        let expected = FieldElement([2, 0, 0, 0, 0, 0, 0, 0]);

        assert_eq!(two.0, expected.0);
        assert!(bool::from(two.ct_eq(&expected)));
    }

    /// 1 * 0 equals 0, both limb-for-limb and under constant-time equality.
    #[test]
    fn test_one_times_zero() {
        let one = FieldElement::ONE;
        let zero = FieldElement::ZERO;

        let result = &one * &zero;

        assert_eq!(result.0, zero.0);
        assert!(bool::from(result.ct_eq(&zero)));
    }

    /// 2 * 3 equals 1 + 1 + 1 + 1 + 1 + 1 once both are canonically encoded.
    #[test]
    fn test_two_times_three_is_six() {
        let one = FieldElement::ONE;
        let two = &one + &one;
        let three = &two + &one;

        let two_times_three = &two * &three;

        // Build six by repeated addition of one.
        let six = (1..=6).fold(FieldElement::ZERO, |partial_sum, _| {
            &partial_sum + &FieldElement::ONE
        });

        assert_eq!(two_times_three.to_bytes(), six.to_bytes());
        assert!(bool::from(two_times_three.ct_eq(&six)));
    }

    /// x + (-x) encodes to the same bytes as zero.
    #[test]
    fn test_negation() {
        let d2 = FieldElement::D2;
        let minus_d2 = -&d2;
        let maybe_zero = &d2 + &minus_d2;

        assert_eq!(FieldElement::ZERO.to_bytes(), maybe_zero.to_bytes());
    }

    /// x * x^-1 equals one under byte, constant-time and `PartialEq` comparison.
    #[test]
    fn test_inversion() {
        let d2 = FieldElement::D2;
        let maybe_inverse = d2.inverse();

        let maybe_one = &d2 * &maybe_inverse;

        assert_eq!(maybe_one.to_bytes(), FieldElement::ONE.to_bytes());
        assert!(bool::from(maybe_one.ct_eq(&FieldElement::ONE)));
        assert_eq!(maybe_one, FieldElement::ONE);
    }
}