use crate::key::public::PublicKey;
use crate::limb::{Limb, LIMB_LENGTH, ONE};
use crate::norop::{norop_limbs_equal_with, norop_limbs_less_than};
use crate::sm2p256::{
add_mod, base_point_mul, inv_sqr, mont_pro, point_add, point_mul, scalar_add_mod, scalar_inv,
scalar_mont_pro, scalar_sub_mod, CURVE_PARAMS,
};
use core::marker::PhantomData;
/// Encoding marker: the value is stored as-is, with no factor of the Montgomery radix.
#[derive(Clone, Copy)]
pub enum Unencoded {}

/// Encoding marker: the value carries one factor of the Montgomery radix `R`.
#[derive(Clone, Copy)]
pub enum R {}

/// Encoding marker: the value carries `R^2` (one reduction turns it into [`R`]).
// `RR` mirrors the mathematical name `R^2`; clippy would otherwise demand `Rr`.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Copy)]
pub enum RR {}

/// Encoding marker: the value carries `R^-1` (one reduction past [`Unencoded`]).
#[derive(Clone, Copy)]
pub enum RInverse {}

/// Common bound satisfied by every encoding marker above.
pub trait Encoding {}

impl Encoding for Unencoded {}
impl Encoding for R {}
impl Encoding for RR {}
impl Encoding for RInverse {}

/// The encoding obtained after removing one factor of `R` from `Self`.
///
/// The chain is `RR -> R -> Unencoded -> RInverse`.
pub trait ReductionEncoding {
    type Output: Encoding;
}

impl ReductionEncoding for RR {
    type Output = R;
}

impl ReductionEncoding for R {
    type Output = Unencoded;
}

impl ReductionEncoding for Unencoded {
    type Output = RInverse;
}

// `RInverse` intentionally has no impl: reducing it would need an `R^-2` encoding.

/// Maps an operand-encoding pair `(A, B)` to the encoding of their Montgomery product.
pub trait ProductEncoding {
    type Output: Encoding;
}

// R * E * R^-1 = E: an `R`-encoded left operand leaves the right one's encoding intact.
impl<E: Encoding> ProductEncoding for (R, E) {
    type Output = E;
}

// 1 * E * R^-1: exactly one reduction applied to `E`.
impl<E: ReductionEncoding> ProductEncoding for (Unencoded, E) {
    type Output = E::Output;
}

// R^-1 * E * R^-1: two reductions of `E`, so both steps must exist.
impl<E: ReductionEncoding> ProductEncoding for (RInverse, E)
where
    E::Output: ReductionEncoding,
{
    type Output = <E::Output as ReductionEncoding>::Output;
}

// A blanket "swap the operands" impl would overlap the impls above, so the
// `RR`-first pairs are listed individually and delegate to their mirrored pair.
impl ProductEncoding for (RR, Unencoded) {
    type Output = <(Unencoded, RR) as ProductEncoding>::Output;
}

impl ProductEncoding for (RR, RInverse) {
    type Output = <(RInverse, RR) as ProductEncoding>::Output;
}
/// A fixed-width multi-limb value tagged at the type level with its encoding `M`.
///
/// `M` lives only in the type system (via `PhantomData`), so passing, say, an
/// `R`-encoded value where an `Unencoded` one is expected fails to compile.
#[derive(Clone, Copy)]
pub struct Elem<M> {
    /// The raw limbs of the value.
    pub limbs: [Limb; LIMB_LENGTH],
    /// Zero-sized encoding tag.
    pub m: PhantomData<M>,
}

impl<M> Elem<M> {
    /// Returns the element whose limbs are all zero.
    pub fn zero() -> Self {
        let limbs = [0; LIMB_LENGTH];
        Elem {
            limbs,
            m: PhantomData,
        }
    }

    /// Returns `true` when every limb is zero (checked with `norop_limbs_equal_with`).
    pub fn is_zero(&self) -> bool {
        self.is_equal(&Self::zero())
    }

    /// Returns `true` when `self` and `other` have identical limbs.
    pub fn is_equal(&self, other: &Elem<M>) -> bool {
        norop_limbs_equal_with(&self.limbs, &other.limbs)
    }
}
/// Multiplies two field elements with `mont_pro`.
///
/// The result's encoding is computed at compile time from the operand encodings
/// through [`ProductEncoding`]; unsupported encoding pairs are rejected by the compiler.
pub fn elem_mul<EA: Encoding, EB: Encoding>(
    a: &Elem<EA>,
    b: &Elem<EB>,
) -> Elem<<(EA, EB) as ProductEncoding>::Output>
where
    (EA, EB): ProductEncoding,
{
    let product = mont_pro(&a.limbs, &b.limbs);
    Elem {
        limbs: product,
        m: PhantomData,
    }
}
/// Adds two `R`-encoded field elements with `add_mod`; the sum is also `R`-encoded.
pub fn elem_add(a: &Elem<R>, b: &Elem<R>) -> Elem<R> {
    let sum = add_mod(&a.limbs, &b.limbs);
    Elem {
        limbs: sum,
        m: PhantomData,
    }
}
/// Applies `inv_sqr` to a nonzero `R`-encoded element and returns the result as `R`-encoded.
///
/// Judging by the name, `inv_sqr` yields the inverse square `a^-2`; TODO confirm
/// against its definition in `sm2p256`.
///
/// # Panics
///
/// Panics if `a` is zero, since zero has no inverse.
pub fn elem_inv_sqr_to_mont(a: &Elem<R>) -> Elem<R> {
    // Reuse the shared `Elem::is_zero` helper instead of re-spelling the limb comparison.
    assert!(!a.is_zero(), "elem_inv_sqr_to_mont: zero has no inverse");
    Elem {
        limbs: inv_sqr(&a.limbs),
        m: PhantomData,
    }
}
/// Converts an `R`-encoded field element to its unencoded value.
///
/// The Montgomery product with `ONE` removes one factor of `R`, mirroring the
/// `(R, Unencoded) -> Unencoded` rule. This assumes `ONE` is the plain integer 1
/// rather than `R mod p`; TODO confirm in `crate::limb`.
pub fn elem_to_unencoded(a: &Elem<R>) -> Elem<Unencoded> {
    let stripped = mont_pro(&a.limbs, &ONE);
    Elem {
        limbs: stripped,
        m: PhantomData,
    }
}
/// Reinterprets an unencoded field element as a scalar.
///
/// The element is kept unchanged when it is below `CURVE_PARAMS.n`. Otherwise it is
/// reduced once with `scalar_sub_mod`. A single subtraction only fully reduces inputs
/// below `2n`, so this presumably relies on the field modulus being smaller than `2n`;
/// TODO confirm.
///
/// NOTE(review): the branch taken depends on the value of `e`. If `e` can be secret,
/// consider a constant-time select.
pub fn elem_reduced_to_scalar(e: &Elem<Unencoded>) -> Scalar {
    let limbs = if norop_limbs_less_than(&e.limbs, &CURVE_PARAMS.n) {
        e.limbs
    } else {
        scalar_sub_mod(&e.limbs, &CURVE_PARAMS.n)
    };
    Scalar {
        limbs,
        m: PhantomData,
    }
}
pub fn scalar_to_elem(e: &Scalar) -> Elem<Unencoded> {
Elem {
limbs: e.limbs,
m: PhantomData,
}
}
/// Copies coordinate number `index` (0, 1 or 2) out of a packed three-coordinate point.
///
/// The point is laid out as three consecutive runs of `LIMB_LENGTH` limbs.
fn point_coordinate(p: &[Limb; LIMB_LENGTH * 3], index: usize) -> Elem<R> {
    let start = index * LIMB_LENGTH;
    let mut coordinate = Elem::zero();
    coordinate
        .limbs
        .copy_from_slice(&p[start..start + LIMB_LENGTH]);
    coordinate
}

/// Returns the first (`x`) coordinate of `p`, typed as `R`-encoded.
pub fn point_x(p: &[Limb; LIMB_LENGTH * 3]) -> Elem<R> {
    point_coordinate(p, 0)
}

/// Returns the second (`y`) coordinate of `p`, typed as `R`-encoded.
pub fn point_y(p: &[Limb; LIMB_LENGTH * 3]) -> Elem<R> {
    point_coordinate(p, 1)
}

/// Returns the third (`z`) coordinate of `p`, typed as `R`-encoded.
///
/// A `z` coordinate suggests projective/Jacobian form; TODO confirm in `sm2p256`.
pub fn point_z(p: &[Limb; LIMB_LENGTH * 3]) -> Elem<R> {
    point_coordinate(p, 2)
}
/// A scalar shares `Elem`'s limb representation and defaults to the unencoded form.
///
/// Only the arithmetic differs: scalars go through the `scalar_*` helpers, which
/// presumably work modulo the group order `CURVE_PARAMS.n`.
pub type Scalar<E = Unencoded> = Elem<E>;
/// Inverts a nonzero unencoded scalar with `scalar_inv` and returns the result as `R`-encoded.
///
/// # Panics
///
/// Panics if `a` is zero, since zero has no inverse.
pub fn scalar_inv_to_mont(a: &Scalar) -> Scalar<R> {
    // Reuse the shared `Elem::is_zero` helper instead of re-spelling the limb comparison.
    assert!(!a.is_zero(), "scalar_inv_to_mont: zero has no inverse");
    Scalar {
        limbs: scalar_inv(&a.limbs),
        m: PhantomData,
    }
}
/// Converts an `R`-encoded scalar to its unencoded value.
///
/// This takes the scalar Montgomery product with `ONE`, which removes one factor of `R`,
/// assuming `ONE` is the plain integer 1 (TODO confirm).
pub fn scalar_to_unencoded(a: &Scalar<R>) -> Scalar {
    let stripped = scalar_mont_pro(&a.limbs, &ONE);
    Scalar {
        limbs: stripped,
        m: PhantomData,
    }
}
/// Multiplies two scalars with `scalar_mont_pro`.
///
/// As with `elem_mul`, the result's encoding is derived at compile time from the
/// operand encodings through [`ProductEncoding`].
pub fn scalar_mul<EA: Encoding, EB: Encoding>(
    a: &Scalar<EA>,
    b: &Scalar<EB>,
) -> Scalar<<(EA, EB) as ProductEncoding>::Output>
where
    (EA, EB): ProductEncoding,
{
    let product = scalar_mont_pro(&a.limbs, &b.limbs);
    Scalar {
        limbs: product,
        m: PhantomData,
    }
}
/// Adds two unencoded scalars with `scalar_add_mod`.
pub fn scalar_add(a: &Scalar, b: &Scalar) -> Scalar {
    let sum = scalar_add_mod(&a.limbs, &b.limbs);
    Scalar {
        limbs: sum,
        m: PhantomData,
    }
}
/// Subtracts scalar `b` from scalar `a` with `scalar_sub_mod`.
pub fn scalar_sub(a: &Scalar, b: &Scalar) -> Scalar {
    let difference = scalar_sub_mod(&a.limbs, &b.limbs);
    Scalar {
        limbs: difference,
        m: PhantomData,
    }
}
/// Multiplies the curve's base point by `g_scalar` via `base_point_mul`.
///
/// Returns the packed three-coordinate point.
fn scalar_g(g_scalar: &Scalar) -> [Limb; LIMB_LENGTH * 3] {
    let k = &g_scalar.limbs;
    base_point_mul(k)
}
/// Multiplies the public key's point by `p_scalar` via `point_mul`.
///
/// Returns the packed three-coordinate point.
fn scalar_p(p_scalar: &Scalar, pk: &PublicKey) -> [Limb; LIMB_LENGTH * 3] {
    point_mul(&pk.to_point(), &p_scalar.limbs)
}
/// Computes `g_scalar * G + p_scalar * P`, where `G` is the base point and `P` is `pk`'s point.
///
/// NOTE(review): whether `point_add` correctly handles equal summands (or the identity)
/// cannot be seen from this file; TODO confirm in `sm2p256`.
pub fn twin_mul(g_scalar: &Scalar, p_scalar: &Scalar, pk: &PublicKey) -> [Limb; LIMB_LENGTH * 3] {
    // Call arguments are evaluated left to right, so `scalar_g` still runs before `scalar_p`.
    point_add(&scalar_g(g_scalar), &scalar_p(p_scalar, pk))
}