use core::ops::Add;
use std::io::Cursor;
use std::ops::{Mul, Sub};
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
use num_bigint::BigUint;
use num_traits::Num;
use crate::sm2::error::{Sm2Error, Sm2Result};
use crate::sm2::p256_ecc::P256C_PARAMS;
use crate::sm2::{FeOperation};
use crate::{forward_ref_ref_binop, forward_ref_val_binop, forward_val_val_binop};
/// A 256-bit value stored as eight 32-bit limbs, most-significant limb first
/// (`[0]` holds the top word, `[7]` the bottom word).
pub type Fe = [u32; 8];

/// Conversions between the limb representation and [`BigUint`].
///
/// NOTE(review): the method names spell "biguint" as "bigunit"; they are kept
/// unchanged because renaming them would break implementors and callers.
pub trait Conversion {
    /// Converts `self` into a [`BigUint`].
    fn fe_to_bigunit(&self) -> BigUint;

    /// Converts `self` into raw limbs.
    fn bigunit_fe(&self) -> Fe;
}
/// The SM2 field prime p = 2^256 - 2^224 - 2^96 + 2^64 - 1, as limbs with the
/// most-significant word first.
///
/// Hex: FFFFFFFE FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF 00000000 FFFFFFFF FFFFFFFF.
pub const ECC_P: Fe = [
    0xFFFF_FFFE, 0xFFFF_FFFF, 0xFFFF_FFFF, 0xFFFF_FFFF,
    0xFFFF_FFFF, 0x0000_0000, 0xFFFF_FFFF, 0xFFFF_FFFF,
];
/// An element of the SM2 prime field, stored as raw [`Fe`] limbs.
///
/// The derived `PartialOrd`/`Ord` compare the limbs lexicographically starting at
/// `inner[0]`, the most-significant word. That ordering therefore matches the
/// numeric order of the stored values. Values are not guaranteed to be reduced
/// modulo p.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct FieldElement {
    pub(crate) inner: Fe,
}
/// Constructors, conversions and arithmetic helpers for [`FieldElement`].
///
/// Limb order throughout: `inner[0]` is the most-significant 32-bit word and
/// `inner[7]` the least-significant one. Constructors do not reduce their input
/// modulo p.
impl FieldElement {
    /// Wraps the given limbs as-is (no reduction modulo p).
    pub fn new(x: Fe) -> FieldElement {
        FieldElement { inner: x }
    }

    /// Builds an element from a slice of limbs, most-significant limb first.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` is not exactly 8.
    pub fn from_slice(x: &[u32]) -> FieldElement {
        let mut arr: Fe = [0; 8];
        arr.copy_from_slice(x);
        FieldElement::new(arr)
    }

    /// Builds the element whose value is `x`.
    ///
    /// The high 32 bits go into `inner[6]` and the low 32 bits into `inner[7]`.
    #[inline]
    pub fn from_number(x: u64) -> FieldElement {
        let mut arr: Fe = [0; 8];
        arr[7] = (x & 0xffff_ffff) as u32;
        arr[6] = (x >> 32) as u32;
        FieldElement { inner: arr }
    }

    /// Serialises the element as exactly 32 big-endian bytes.
    #[inline]
    pub fn to_bytes_be(&self) -> Vec<u8> {
        let mut ret: Vec<u8> = Vec::with_capacity(32);
        for &limb in self.inner.iter() {
            ret.write_u32::<BigEndian>(limb)
                .expect("writing into a Vec<u8> cannot fail");
        }
        ret
    }

    /// Parses exactly 32 big-endian bytes into an element (no reduction modulo p).
    ///
    /// # Errors
    ///
    /// Returns [`Sm2Error::InvalidFieldLen`] if `bytes.len() != 32`.
    #[inline]
    pub fn from_bytes_be(bytes: &[u8]) -> Sm2Result<FieldElement> {
        if bytes.len() != 32 {
            return Err(Sm2Error::InvalidFieldLen);
        }
        let mut elem = FieldElement::zero();
        let mut c = Cursor::new(bytes);
        for limb in elem.inner.iter_mut() {
            // Cannot run out of input: the length was checked to be 8 * 4 bytes.
            *limb = c
                .read_u32::<BigEndian>()
                .expect("input length checked to be 32 bytes");
        }
        Ok(elem)
    }

    /// Converts the stored limbs into a [`BigUint`].
    pub fn to_biguint(&self) -> BigUint {
        BigUint::from_bytes_be(&self.to_bytes_be())
    }

    /// Converts a [`BigUint`] into an element, left-padding it to 32 bytes.
    ///
    /// Values in `[p, 2^256)` are accepted unreduced.
    ///
    /// # Errors
    ///
    /// Returns [`Sm2Error::InvalidFieldLen`] if `bi` needs more than 32 bytes,
    /// i.e. `bi >= 2^256`. Previously this case panicked on a length underflow.
    pub fn from_biguint(bi: &BigUint) -> Sm2Result<FieldElement> {
        let v = bi.to_bytes_be();
        if v.len() > 32 {
            return Err(Sm2Error::InvalidFieldLen);
        }
        let mut num_v = [0u8; 32];
        num_v[32 - v.len()..].copy_from_slice(&v);
        FieldElement::from_bytes_be(&num_v)
    }

    /// Computes a square root as `self^((p + 1) / 4)`.
    ///
    /// This is valid because the SM2 prime satisfies p ≡ 3 (mod 4). The candidate is
    /// squared and compared limb-for-limb with `self`.
    ///
    /// # Errors
    ///
    /// Returns [`Sm2Error::FieldSqrtError`] when the candidate's square does not
    /// equal `self`, e.g. when `self` is not a quadratic residue.
    pub fn sqrt(&self) -> Sm2Result<FieldElement> {
        // (p + 1) / 4 = 0x3FFFFFFF_BFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_C0000000_40000000_00000000
        let u = BigUint::from_str_radix(
            "28948022302589062189105086303505223191562588497981047863605298483322421248000",
            10,
        )
        .unwrap();
        let y = self.modpow(&u);
        let z = &y.square();
        if z == self {
            Ok(y)
        } else {
            Err(Sm2Error::FieldSqrtError)
        }
    }

    /// Formats the stored value in the given radix (via [`BigUint::to_str_radix`]).
    #[inline]
    pub fn to_str_radix(&self, radix: u32) -> String {
        self.to_biguint().to_str_radix(radix)
    }

    /// The additive identity (all limbs zero).
    pub fn zero() -> FieldElement {
        FieldElement::new([0; 8])
    }

    /// The multiplicative identity.
    pub fn one() -> FieldElement {
        FieldElement::from_number(1)
    }

    /// Returns `true` if the lowest bit of the stored value is clear.
    pub fn is_even(&self) -> bool {
        self.inner[7] & 0x01 == 0
    }

    /// Returns `true` if every limb is zero.
    pub fn is_zero(&self) -> bool {
        self.inner == [0; 8]
    }

    /// Returns `true` only when the stored value is exactly 1.
    ///
    /// The previous implementation checked `inner[7]` alone, which misreported any
    /// value whose least-significant limb happened to be 1.
    pub fn is_one(&self) -> bool {
        self.inner == [0, 0, 0, 0, 0, 0, 0, 1]
    }

    /// Returns `self * self` modulo p.
    pub fn square(&self) -> FieldElement {
        *self * *self
    }

    /// Returns `self + self` modulo p.
    pub fn double(&self) -> FieldElement {
        *self + *self
    }

    /// Raises `self` to `exponent` with a Montgomery ladder over all 256 exponent bits.
    ///
    /// Every bit performs the same pair of operations (one multiply, one square),
    /// regardless of the bit's value.
    ///
    /// # Panics
    ///
    /// Panics if `exponent` does not fit in 256 bits.
    pub fn modpow(&self, exponent: &BigUint) -> Self {
        let u = FieldElement::from_biguint(exponent)
            .expect("modpow exponent must fit in 256 bits");
        // Ladder invariant: q1 == q0 * self after every step.
        let mut q0 = FieldElement::one();
        let mut q1 = *self;
        for i in 0..256usize {
            // Walk bits from most- to least-significant: limb i / 32, bit 31 - i % 32.
            let index = i / 32;
            let bit = 31 - i % 32;
            let sum = &q0 * &q1;
            if (u.inner[index] >> bit) & 0x01 == 0 {
                q1 = sum;
                q0 = q0.square();
            } else {
                q0 = sum;
                q1 = q1.square();
            }
        }
        q0
    }

    /// Multiplicative inverse modulo `P256C_PARAMS.p`, computed by `FeOperation::inv`.
    ///
    /// NOTE(review): behaviour for a zero input is decided by `inv` and is not
    /// checked here. Confirm whether callers must guard against zero.
    pub fn modinv(&self) -> FieldElement {
        let ecc_p = &P256C_PARAMS.p;
        let ret = self.inner.inv(&ecc_p.inner);
        FieldElement::new(ret)
    }
}
forward_val_val_binop!(impl Add for FieldElement, add);
forward_ref_ref_binop!(impl Add for FieldElement, add);
forward_ref_val_binop!(impl Add for FieldElement, add);
impl<'a> Add<&'a FieldElement> for FieldElement {
type Output = FieldElement;
fn add(mut self, rhs: &FieldElement) -> Self::Output {
self.inner = self.inner.mod_add(&rhs.inner, &ECC_P);
self
}
}
forward_val_val_binop!(impl Sub for FieldElement, sub);
forward_ref_ref_binop!(impl Sub for FieldElement, sub);
forward_ref_val_binop!(impl Sub for FieldElement, sub);
impl<'a> Sub<&'a FieldElement> for FieldElement {
type Output = FieldElement;
fn sub(mut self, rhs: &'a FieldElement) -> Self::Output {
self.inner = self.inner.mod_sub(&rhs.inner, &ECC_P);
self
}
}
forward_val_val_binop!(impl Mul for FieldElement, mul);
forward_ref_ref_binop!(impl Mul for FieldElement, mul);
forward_ref_val_binop!(impl Mul for FieldElement, mul);
impl<'a> Mul<&'a FieldElement> for FieldElement {
type Output = FieldElement;
fn mul(mut self, rhs: &'a FieldElement) -> Self::Output {
self.inner = self.inner.mod_mul(&rhs.inner, &ECC_P);
self
}
}
/// `FieldElement + u64`: lifts `rhs` with `from_number`, then adds modulo [`ECC_P`].
impl Add<u64> for FieldElement {
    type Output = FieldElement;

    fn add(self, rhs: u64) -> Self::Output {
        let addend = FieldElement::from_number(rhs);
        FieldElement::new(self.inner.mod_add(&addend.inner, &ECC_P))
    }
}
/// `FieldElement * u64`: lifts `rhs` with `from_number`, then multiplies modulo [`ECC_P`].
impl Mul<u64> for FieldElement {
    type Output = FieldElement;

    fn mul(self, rhs: u64) -> Self::Output {
        let factor = FieldElement::from_number(rhs);
        FieldElement::new(self.inner.mod_mul(&factor.inner, &ECC_P))
    }
}
impl<'a> Mul<u64> for &'a FieldElement {
type Output = FieldElement;
fn mul(self, rhs: u64) -> Self::Output {
let mut s = self.clone();
s.inner = s
.inner
.mod_mul(&FieldElement::from_number(rhs).inner, &ECC_P);
s
}
}
impl Default for FieldElement {
#[inline]
fn default() -> FieldElement {
FieldElement {
inner: [0; 8],
}
}
}