use crate::rand::fill_bytes;
use core::fmt::{Debug, Formatter};
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
/// Generates a fresh X25519 key pair, returned as `(private_key, public_key)`.
///
/// Public entry point; the work is done by `generate_x25519_keypair_no_dalek`.
pub fn generate_x25519_keypair() -> ([u8; 32], [u8; 32]) {
    let (private_key, public_key) = generate_x25519_keypair_no_dalek();
    (private_key, public_key)
}
/// Computes the X25519 function of `private_key` and `peer_public_key`.
///
/// Public entry point; both inputs are forwarded unchanged to `x25519_no_dalek`, which
/// clamps the private key before use.
pub fn x25519(private_key: [u8; 32], peer_public_key: [u8; 32]) -> [u8; 32] {
    x25519_no_dalek(
        private_key,
        peer_public_key,
    )
}
/// Generates an X25519 key pair from fresh randomness.
///
/// 32 bytes are drawn with `fill_bytes` and clamped with `clamp_integer`. The public key
/// is the u-coordinate of the base point (u = 9) multiplied by that clamped scalar.
///
/// Returns `(private_key, public_key)`. The private key is returned in its clamped form.
pub fn generate_x25519_keypair_no_dalek() -> ([u8; 32], [u8; 32]) {
    let secret = {
        let mut seed = [0u8; 32];
        fill_bytes(&mut seed);
        clamp_integer(seed)
    };
    let public = &MontgomeryPoint(X25519_BASEPOINT_BYTES) * &Scalar::new(secret);
    (secret, public.to_bytes())
}
/// Performs X25519 Diffie-Hellman.
///
/// `private_key` is clamped with `clamp_integer`; clamping an already-clamped key leaves
/// it unchanged. The peer's u-coordinate is then multiplied by the clamped key, and the
/// resulting u-coordinate is returned as 32 little-endian bytes.
///
/// NOTE(review): the result is not checked for the all-zero value that low-order peer
/// points produce; callers needing that check must do it themselves.
pub fn x25519_no_dalek(private_key: [u8; 32], peer_public_key: [u8; 32]) -> [u8; 32] {
    let scalar = Scalar::new(clamp_integer(private_key));
    let shared = MontgomeryPoint(peer_public_key) * scalar;
    shared.to_bytes()
}
/// The Montgomery-curve constant (A + 2) / 4 for Curve25519, stored as radix-2^51 limbs.
///
/// 121666 = (486662 + 2) / 4, i.e. A = 486662. Used in the doubling step of
/// `differential_add_and_double`.
pub(crate) const APLUS2_OVER_FOUR: FieldElement51 =
    FieldElement51::from_limbs_const([121666, 0, 0, 0, 0]);
/// A square root of -1 modulo p = 2^255 - 19 (as the name states), in radix-2^51 limbs.
///
/// `FieldElement51::sqrt_ratio_i` uses it to adjust a candidate square root.
pub(crate) const SQRT_M1: FieldElement51 = FieldElement51::from_limbs_const([
    1718705420411056,
    234908883556509,
    2233514472574048,
    2117202627021982,
    765476049583133,
]);
/// Little-endian 32-byte encoding of the X25519 base point's u-coordinate, u = 9.
pub const X25519_BASEPOINT_BYTES: [u8; 32] = [
    9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0,
];
/// An element of GF(2^255 - 19) in unsaturated radix-2^51 form.
///
/// The represented value is `l[0] + l[1]*2^51 + l[2]*2^102 + l[3]*2^153 + l[4]*2^204`.
/// Limbs are not kept canonical. Arithmetic may leave them slightly wider than 51 bits,
/// and `as_bytes` performs the final canonical reduction.
#[derive(Copy, Clone)]
pub struct FieldElement51(pub [u64; 5]);
impl Debug for FieldElement51 {
    /// Prints `FieldElement51([l0, l1, l2, l3, l4])` using the raw, possibly unreduced limbs.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let limbs: &[u64] = &self.0;
        write!(f, "FieldElement51({:?})", limbs)
    }
}
impl Eq for FieldElement51 {}

impl PartialEq for FieldElement51 {
    /// Value equality, decided through the constant-time comparison below.
    fn eq(&self, other: &FieldElement51) -> bool {
        let same: Choice = self.ct_eq(other);
        bool::from(same)
    }
}

impl ConstantTimeEq for FieldElement51 {
    /// Compares the canonical 32-byte encodings.
    ///
    /// Different limb representations of the same residue therefore compare equal.
    fn ct_eq(&self, other: &FieldElement51) -> Choice {
        let lhs = self.as_bytes();
        let rhs = other.as_bytes();
        lhs.ct_eq(&rhs)
    }
}
impl ConditionallySelectable for FieldElement51 {
    /// Limb-by-limb constant-time select.
    ///
    /// Yields `a` when `choice` is 0 and `b` when it is 1.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let mut limbs = [0u64; 5];
        for (slot, (x, y)) in limbs.iter_mut().zip(a.0.iter().zip(b.0.iter())) {
            *slot = u64::conditional_select(x, y, choice);
        }
        FieldElement51(limbs)
    }
}
impl<'b> AddAssign<&'b FieldElement51> for FieldElement51 {
fn add_assign(&mut self, rhs: &'b FieldElement51) {
for i in 0..5 {
self.0[i] += rhs.0[i];
}
}
}
impl<'a, 'b> Add<&'b FieldElement51> for &'a FieldElement51 {
type Output = FieldElement51;
fn add(self, rhs: &'b FieldElement51) -> FieldElement51 {
let mut output = *self;
output += rhs;
output
}
}
impl<'b> SubAssign<&'b FieldElement51> for FieldElement51 {
    /// In-place subtraction; delegates to the by-reference `Sub` impl.
    fn sub_assign(&mut self, rhs: &'b FieldElement51) {
        let difference = &*self - rhs;
        *self = difference;
    }
}

impl<'a, 'b> Sub<&'b FieldElement51> for &'a FieldElement51 {
    type Output = FieldElement51;
    /// Computes `self - rhs`.
    ///
    /// To avoid unsigned underflow, a multiple of p is added to `self` before `rhs` is
    /// subtracted: 16*(2^51 - 19) in limb 0 and 16*(2^51 - 1) in the other limbs, i.e.
    /// 16*p. Each limb of `rhs` must therefore not exceed that bias. The result is then
    /// weakly reduced.
    fn sub(self, rhs: &'b FieldElement51) -> FieldElement51 {
        const BIAS: [u64; 5] = [
            36028797018963664,
            36028797018963952,
            36028797018963952,
            36028797018963952,
            36028797018963952,
        ];
        let mut limbs = [0u64; 5];
        for (i, out) in limbs.iter_mut().enumerate() {
            *out = (self.0[i] + BIAS[i]) - rhs.0[i];
        }
        FieldElement51::reduce_after_sub(limbs)
    }
}
impl<'b> MulAssign<&'b FieldElement51> for FieldElement51 {
    /// In-place multiplication; delegates to the by-reference `Mul` impl.
    fn mul_assign(&mut self, rhs: &'b FieldElement51) {
        let product = &*self * rhs;
        *self = product;
    }
}
impl<'a, 'b> Mul<&'b FieldElement51> for &'a FieldElement51 {
    type Output = FieldElement51;
    /// Multiplies two field elements.
    ///
    /// Computes the 5x5 schoolbook product of the radix-2^51 limbs using 128-bit
    /// accumulators. Because 2^255 = 19 (mod p), a partial product a[i]*b[j] with
    /// i + j >= 5 is folded back into coefficient i + j - 5, scaled by 19; this is why
    /// the `b*_19` values are precomputed.
    ///
    /// One carry pass follows. It leaves limbs 0 and 2..4 within 51 bits and limb 1
    /// slightly wider. The result is not canonically reduced.
    #[rustfmt::skip]
    fn mul(self, rhs: &'b FieldElement51) -> FieldElement51 {
        // Widening 64 x 64 -> 128-bit multiplication.
        #[inline(always)]
        fn m(x: u64, y: u64) -> u128 { (x as u128) * (y as u128) }
        let a: &[u64; 5] = &self.0;
        let b: &[u64; 5] = &rhs.0;
        // 19*b[j] for the wrap-around terms.
        // NOTE(review): assumes b limbs are small enough (well under 2^59) for *19 to fit in u64.
        let b1_19 = b[1] * 19; let b2_19 = b[2] * 19; let b3_19 = b[3] * 19; let b4_19 = b[4] * 19;
        // c_k = sum of a[i]*b[j] over i + j = k, plus 19*a[i]*b[j] over i + j = k + 5.
        let mut c0: u128 = m(a[0], b[0]);
        c0 += m(a[4], b1_19); c0 += m(a[3], b2_19); c0 += m(a[2], b3_19); c0 += m(a[1], b4_19);
        let mut c1: u128 = m(a[1], b[0]);
        c1 += m(a[0], b[1]); c1 += m(a[4], b2_19); c1 += m(a[3], b3_19); c1 += m(a[2], b4_19);
        let mut c2: u128 = m(a[2], b[0]);
        c2 += m(a[1], b[1]); c2 += m(a[0], b[2]); c2 += m(a[4], b3_19); c2 += m(a[3], b4_19);
        let mut c3: u128 = m(a[3], b[0]);
        c3 += m(a[2], b[1]); c3 += m(a[1], b[2]); c3 += m(a[0], b[3]); c3 += m(a[4], b4_19);
        let mut c4: u128 = m(a[4], b[0]);
        c4 += m(a[3], b[1]); c4 += m(a[2], b[2]); c4 += m(a[1], b[3]); c4 += m(a[0], b[4]);
        const LOW_51_BIT_MASK: u64 = (1u64 << 51) - 1;
        let mut out = [0u64; 5];
        // Carry chain: keep the low 51 bits of each coefficient and push the rest upward.
        c1 += (c0 >> 51) as u64 as u128; out[0] = (c0 as u64) & LOW_51_BIT_MASK;
        c2 += (c1 >> 51) as u64 as u128; out[1] = (c1 as u64) & LOW_51_BIT_MASK;
        c3 += (c2 >> 51) as u64 as u128; out[2] = (c2 as u64) & LOW_51_BIT_MASK;
        c4 += (c3 >> 51) as u64 as u128; out[3] = (c3 as u64) & LOW_51_BIT_MASK;
        // The carry out of the top coefficient wraps around to limb 0 (times 19).
        let carry: u64 = (c4 >> 51) as u64;
        out[4] = (c4 as u64) & LOW_51_BIT_MASK;
        // Fold the wrap-around carry into limb 0, then push limb 0's overflow into limb 1.
        out[0] += carry * 19; out[1] += out[0] >> 51; out[0] &= LOW_51_BIT_MASK;
        FieldElement51(out)
    }
}
impl<'a> Neg for &'a FieldElement51 {
    type Output = FieldElement51;
    /// Returns `-self`, computed as `0 - self` with the biased subtraction.
    fn neg(self) -> FieldElement51 {
        &FieldElement51::ZERO - self
    }
}
impl FieldElement51 {
    /// Builds a field element from raw radix-2^51 limbs; usable in `const` contexts.
    ///
    /// The limbs are stored verbatim; no reduction is performed.
    pub(crate) const fn from_limbs_const(limbs: [u64; 5]) -> Self {
        FieldElement51(limbs)
    }

    /// Builds a field element from raw radix-2^51 limbs, stored verbatim (no reduction).
    pub fn from_limbs(limbs: [u64; 5]) -> Self {
        FieldElement51(limbs)
    }

    /// The additive identity, 0.
    pub const ZERO: FieldElement51 = FieldElement51::from_limbs_const([0, 0, 0, 0, 0]);

    /// The multiplicative identity, 1.
    pub const ONE: FieldElement51 = FieldElement51::from_limbs_const([1, 0, 0, 0, 0]);

    /// -1, i.e. p - 1 = 2^255 - 20, in canonical limbs.
    pub const MINUS_ONE: FieldElement51 = FieldElement51::from_limbs_const([
        2251799813685228,
        2251799813685247,
        2251799813685247,
        2251799813685247,
        2251799813685247,
    ]);

    /// Raises `self` to the power (p - 5) / 8 = 2^252 - 3.
    ///
    /// Reuses the `pow22501` chain (exponent 2^250 - 1). Two squarings give 2^252 - 4,
    /// and one multiplication by `self` adds 1.
    pub fn pow_p58(&self) -> FieldElement51 {
        let (t19, _) = self.pow22501(); // self^(2^250 - 1)
        let t20 = t19.pow2k(2); // self^(2^252 - 4)
        self * &t20 // self^(2^252 - 3)
    }

    /// Computes a constant-time square root of the ratio `u / v`.
    ///
    /// A candidate `r = (u*v^3) * (u*v^7)^((p-5)/8)` is formed and checked via `v*r^2`:
    /// - if `v*r^2 == u`, `r` is kept;
    /// - if `v*r^2 == -u` or `v*r^2 == -u*i` (`i` = `SQRT_M1`), `r` is replaced by `i*r`,
    ///   which negates `v*r^2`.
    ///
    /// `r` is then negated when its canonical encoding is odd, so the returned root is
    /// always the "nonnegative" one. It is forced to zero when `u` is zero.
    ///
    /// Returns `(was_nonzero_square, r)`. `was_nonzero_square` is 1 when the candidate
    /// satisfied `v*r^2 == u` or `v*r^2 == -u`.
    pub fn sqrt_ratio_i(u: &FieldElement51, v: &FieldElement51) -> (Choice, FieldElement51) {
        let v3 = &v.square() * v;
        let v7 = &v3.square() * v;
        let mut r = &(u * &v3) * &(u * &v7).pow_p58();
        let check = v * &r.square();

        let i = &SQRT_M1;
        let correct_sign_sqrt = check.ct_eq(u);
        let neg_u = -u;
        let flipped_sign_sqrt = check.ct_eq(&neg_u);
        let neg_u_times_i = &neg_u * i;
        let flipped_sign_sqrt_i = check.ct_eq(&neg_u_times_i);

        // Multiplying by sqrt(-1) flips the sign of v*r^2 in both "flipped" cases.
        let r_prime = i * &r;
        r.conditional_assign(&r_prime, flipped_sign_sqrt | flipped_sign_sqrt_i);

        // Always return the root whose canonical encoding is even.
        let r_is_negative = r.is_negative();
        r.conditional_negate(r_is_negative);

        let was_nonzero_square = correct_sign_sqrt | flipped_sign_sqrt;

        // Force r = 0 when u = 0.
        let u_is_zero = u.is_zero();
        let r_if_u_is_zero = FieldElement51::ZERO;
        r.conditional_assign(&r_if_u_is_zero, u_is_zero);

        (was_nonzero_square, r)
    }

    /// Computes `sqrt_ratio_i(1, self)`, an inverse square root of `self`.
    ///
    /// The returned flag has the same meaning as in `sqrt_ratio_i`.
    pub fn invsqrt(&self) -> (Choice, FieldElement51) {
        FieldElement51::sqrt_ratio_i(&FieldElement51::ONE, self)
    }

    /// Returns `2*self^2`, computed by doubling each limb of `self.square()`.
    ///
    /// No carry is propagated after the doubling, so limbs may be about one bit wider.
    pub fn square2(&self) -> FieldElement51 {
        let mut square = self.pow2k(1);
        for limb in square.0.iter_mut() {
            *limb *= 2;
        }
        square
    }

    /// Weak reduction applied to the biased difference produced by `Sub`.
    ///
    /// This is exactly `reduce`; it delegates there so the carry logic lives in one place.
    #[inline(always)]
    fn reduce_after_sub(limbs: [u64; 5]) -> FieldElement51 {
        FieldElement51::reduce(limbs)
    }

    /// Decodes 32 little-endian bytes into radix-2^51 limbs.
    ///
    /// Bit 255 is ignored. The value is not reduced modulo p, so non-canonical encodings
    /// (values in [p, 2^255)) are accepted as-is.
    #[rustfmt::skip]
    pub fn from_bytes(bytes: &[u8; 32]) -> FieldElement51 {
        // Little-endian load of the first 8 bytes of `input`.
        let load8 = |input: &[u8]| -> u64 {
            (input[0] as u64)
            | ((input[1] as u64) << 8) | ((input[2] as u64) << 16) | ((input[3] as u64) << 24)
            | ((input[4] as u64) << 32) | ((input[5] as u64) << 40) | ((input[6] as u64) << 48)
            | ((input[7] as u64) << 56)
        };
        const LOW_51_BIT_MASK: u64 = (1u64 << 51) - 1;
        // Each limb loads 64 bits at a byte offset, then shifts so bit 51*k lands at bit 0.
        FieldElement51(
            [ load8(&bytes[ 0..])        & LOW_51_BIT_MASK   // bits   0..51
            , (load8(&bytes[ 6..]) >>  3) & LOW_51_BIT_MASK  // bits  51..102
            , (load8(&bytes[12..]) >>  6) & LOW_51_BIT_MASK  // bits 102..153
            , (load8(&bytes[19..]) >>  1) & LOW_51_BIT_MASK  // bits 153..204
            , (load8(&bytes[24..]) >> 12) & LOW_51_BIT_MASK  // bits 204..255
            ])
    }

    /// Weakly reduces `limbs` so each limb is at most just over 51 bits.
    ///
    /// Every limb's bits above bit 51 are carried into the next limb. The carry out of the
    /// top limb wraps to limb 0 multiplied by 19, since 2^255 = 19 (mod p). All carries
    /// are taken from the input simultaneously, so the result is not canonical.
    #[inline(always)]
    pub(crate) fn reduce(mut limbs: [u64; 5]) -> FieldElement51 {
        const LOW_51_BIT_MASK: u64 = (1u64 << 51) - 1;
        let c0 = limbs[0] >> 51;
        let c1 = limbs[1] >> 51;
        let c2 = limbs[2] >> 51;
        let c3 = limbs[3] >> 51;
        let c4 = limbs[4] >> 51;
        limbs[0] &= LOW_51_BIT_MASK;
        limbs[1] &= LOW_51_BIT_MASK;
        limbs[2] &= LOW_51_BIT_MASK;
        limbs[3] &= LOW_51_BIT_MASK;
        limbs[4] &= LOW_51_BIT_MASK;
        limbs[0] += c4 * 19;
        limbs[1] += c0;
        limbs[2] += c1;
        limbs[3] += c2;
        limbs[4] += c3;
        FieldElement51(limbs)
    }

    /// Encodes `self` as 32 little-endian bytes, fully reduced modulo p (bit 255 is 0).
    #[rustfmt::skip]
    pub fn as_bytes(&self) -> [u8; 32] {
        // Weakly reduce first so that the value h is below 2p.
        let reduced_self = FieldElement51::reduce(self.0);
        let mut limbs = reduced_self.0;
        // q = 1 exactly when h >= p, i.e. when h + 19 carries out of bit 255.
        let mut q = (limbs[0] + 19) >> 51;
        q = (limbs[1] + q) >> 51;
        q = (limbs[2] + q) >> 51;
        q = (limbs[3] + q) >> 51;
        q = (limbs[4] + q) >> 51;
        // Subtract q*p: add 19*q here and drop the 2^255*q carry out of limb 4 below.
        limbs[0] += 19 * q;
        const LOW_51_BIT_MASK: u64 = (1u64 << 51) - 1;
        limbs[1] += limbs[0] >> 51; limbs[0] &= LOW_51_BIT_MASK;
        limbs[2] += limbs[1] >> 51; limbs[1] &= LOW_51_BIT_MASK;
        limbs[3] += limbs[2] >> 51; limbs[2] &= LOW_51_BIT_MASK;
        limbs[4] += limbs[3] >> 51; limbs[3] &= LOW_51_BIT_MASK;
        limbs[4] &= LOW_51_BIT_MASK;
        // Pack the five 51-bit limbs into 32 bytes.
        let mut s = [0u8;32];
        s[ 0] = (limbs[0] ) as u8; s[ 1] = (limbs[0] >> 8) as u8;
        s[ 2] = (limbs[0] >> 16) as u8; s[ 3] = (limbs[0] >> 24) as u8;
        s[ 4] = (limbs[0] >> 32) as u8; s[ 5] = (limbs[0] >> 40) as u8;
        s[ 6] = ((limbs[0] >> 48) | (limbs[1] << 3)) as u8;
        s[ 7] = (limbs[1] >> 5) as u8; s[ 8] = (limbs[1] >> 13) as u8;
        s[ 9] = (limbs[1] >> 21) as u8; s[10] = (limbs[1] >> 29) as u8;
        s[11] = (limbs[1] >> 37) as u8;
        s[12] = ((limbs[1] >> 45) | (limbs[2] << 6)) as u8;
        s[13] = (limbs[2] >> 2) as u8; s[14] = (limbs[2] >> 10) as u8;
        s[15] = (limbs[2] >> 18) as u8; s[16] = (limbs[2] >> 26) as u8;
        s[17] = (limbs[2] >> 34) as u8; s[18] = (limbs[2] >> 42) as u8;
        s[19] = ((limbs[2] >> 50) | (limbs[3] << 1)) as u8;
        s[20] = (limbs[3] >> 7) as u8; s[21] = (limbs[3] >> 15) as u8;
        s[22] = (limbs[3] >> 23) as u8; s[23] = (limbs[3] >> 31) as u8;
        s[24] = (limbs[3] >> 39) as u8;
        s[25] = ((limbs[3] >> 47) | (limbs[4] << 4)) as u8;
        s[26] = (limbs[4] >> 4) as u8; s[27] = (limbs[4] >> 12) as u8;
        s[28] = (limbs[4] >> 20) as u8; s[29] = (limbs[4] >> 28) as u8;
        s[30] = (limbs[4] >> 36) as u8; s[31] = (limbs[4] >> 44) as u8;
        s
    }

    /// Computes `self^(2^k)` by squaring `k` times.
    ///
    /// `k == 0` returns a copy of `self`. The counter is never decremented below zero, so
    /// there is no underflow.
    #[rustfmt::skip]
    pub fn pow2k(&self, k: u32) -> FieldElement51 {
        // Widening 64 x 64 -> 128-bit multiplication.
        #[inline(always)]
        fn m(x: u64, y: u64) -> u128 { (x as u128) * (y as u128) }
        const LOW_51_BIT_MASK: u64 = (1u64 << 51) - 1;
        let mut a: [u64; 5] = self.0;
        for _ in 0..k {
            // Squaring: each cross term a[i]*a[j] (i != j) appears twice. Terms at or
            // above 2^255 wrap around with a factor of 19.
            let a3_19 = 19 * a[3]; let a4_19 = 19 * a[4];
            let c0: u128 = m(a[0], a[0]) + 2*( m(a[1], a4_19) + m(a[2], a3_19) );
            let mut c1: u128 = m(a[3], a3_19) + 2*( m(a[0], a[1]) + m(a[2], a4_19) );
            let mut c2: u128 = m(a[1], a[1]) + 2*( m(a[0], a[2]) + m(a[4], a3_19) );
            let mut c3: u128 = m(a[4], a4_19) + 2*( m(a[0], a[3]) + m(a[1], a[2]) );
            let mut c4: u128 = m(a[2], a[2]) + 2*( m(a[0], a[4]) + m(a[1], a[3]) );
            // Carry chain, identical to the one in `Mul`.
            c1 += (c0 >> 51) as u64 as u128; a[0] = (c0 as u64) & LOW_51_BIT_MASK;
            c2 += (c1 >> 51) as u64 as u128; a[1] = (c1 as u64) & LOW_51_BIT_MASK;
            c3 += (c2 >> 51) as u64 as u128; a[2] = (c2 as u64) & LOW_51_BIT_MASK;
            c4 += (c3 >> 51) as u64 as u128; a[3] = (c3 as u64) & LOW_51_BIT_MASK;
            let carry: u64 = (c4 >> 51) as u64;
            a[4] = (c4 as u64) & LOW_51_BIT_MASK;
            a[0] += carry * 19; a[1] += a[0] >> 51; a[0] &= LOW_51_BIT_MASK;
        }
        FieldElement51(a)
    }

    /// Returns `self^2`.
    pub fn square(&self) -> FieldElement51 {
        self.pow2k(1)
    }

    /// Returns 1 when the least significant bit of the canonical encoding is set.
    ///
    /// `sqrt_ratio_i` uses this as its "negative" convention.
    pub fn is_negative(&self) -> Choice {
        (self.as_bytes()[0] & 1).into()
    }

    /// Constant-time test for `self = 0 (mod p)`, using the canonical encoding.
    pub fn is_zero(&self) -> Choice {
        let zero_bytes = [0u8; 32];
        self.as_bytes().ct_eq(&zero_bytes)
    }

    /// Shared addition chain for `invert` and `pow_p58`.
    ///
    /// Returns `(self^(2^250 - 1), self^11)`.
    #[rustfmt::skip]
    fn pow22501(&self) -> (FieldElement51, FieldElement51) {
        let t0 = self.square();          // 2
        let t1 = t0.square().square();   // 8
        let t2 = self * &t1;             // 9
        let t3 = &t0 * &t2;              // 11
        let t4 = t3.square();            // 22
        let t5 = &t2 * &t4;              // 31 = 2^5 - 1
        let t6 = t5.pow2k(5);            // 2^10 - 2^5
        let t7 = &t6 * &t5;              // 2^10 - 1
        let t8 = t7.pow2k(10);           // 2^20 - 2^10
        let t9 = &t8 * &t7;              // 2^20 - 1
        let t10 = t9.pow2k(20);          // 2^40 - 2^20
        let t11 = &t10 * &t9;            // 2^40 - 1
        let t12 = t11.pow2k(10);         // 2^50 - 2^10
        let t13 = &t12 * &t7;            // 2^50 - 1
        let t14 = t13.pow2k(50);         // 2^100 - 2^50
        let t15 = &t14 * &t13;           // 2^100 - 1
        let t16 = t15.pow2k(100);        // 2^200 - 2^100
        let t17 = &t16 * &t15;           // 2^200 - 1
        let t18 = t17.pow2k(50);         // 2^250 - 2^50
        let t19 = &t18 * &t13;           // 2^250 - 1
        (t19, t3)
    }

    /// Returns `self^(p - 2) = self^(2^255 - 21)`.
    ///
    /// For nonzero `self` this is the multiplicative inverse (Fermat's little theorem).
    /// Zero maps to zero.
    pub fn invert(&self) -> FieldElement51 {
        let (t19, t3) = self.pow22501(); // self^(2^250 - 1), self^11
        let t20 = t19.pow2k(5); // self^(2^255 - 32)
        &t20 * &t3 // self^(2^255 - 21)
    }
}
impl Add<FieldElement51> for FieldElement51 { type Output = Self; fn add(self, rhs: Self) -> Self { &self + &rhs } }
impl Add<&FieldElement51> for FieldElement51 { type Output = Self; fn add(self, rhs: &Self) -> Self { &self + rhs } }
impl Add<FieldElement51> for &FieldElement51 { type Output = FieldElement51; fn add(self, rhs: FieldElement51) -> FieldElement51 { self + &rhs } }
impl Sub<FieldElement51> for FieldElement51 { type Output = Self; fn sub(self, rhs: Self) -> Self { &self - &rhs } }
impl Sub<&FieldElement51> for FieldElement51 { type Output = Self; fn sub(self, rhs: &Self) -> Self { &self - rhs } }
impl Sub<FieldElement51> for &FieldElement51 { type Output = FieldElement51; fn sub(self, rhs: FieldElement51) -> FieldElement51 { self - &rhs } }
impl Mul<FieldElement51> for FieldElement51 { type Output = Self; fn mul(self, rhs: Self) -> Self { &self * &rhs } }
impl Mul<&FieldElement51> for FieldElement51 { type Output = Self; fn mul(self, rhs: &Self) -> Self { &self * rhs } }
impl Mul<FieldElement51> for &FieldElement51 { type Output = FieldElement51; fn mul(self, rhs: FieldElement51) -> FieldElement51 { self * &rhs } }
impl Neg for FieldElement51 { type Output = Self; fn neg(self) -> Self { -&self } }
/// A scalar for Montgomery-ladder multiplication, held as 32 little-endian bytes.
///
/// Bytes are stored exactly as given; no reduction or clamping happens here.
/// NOTE(review): the derived `Debug` prints the raw bytes and the derived `PartialEq` is
/// not constant-time. Both matter when a `Scalar` holds a private key.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct Scalar {
    /// Little-endian scalar bytes.
    pub bytes: [u8; 32],
}

impl Scalar {
    /// Wraps `bytes` without modification.
    pub fn new(bytes: [u8; 32]) -> Self {
        Self { bytes }
    }

    /// Iterates over all 256 bits, least significant first.
    ///
    /// Bit `i` is bit `i % 8` of byte `i / 8`.
    pub(crate) fn bits_le(&self) -> impl DoubleEndedIterator<Item = bool> + Clone + '_ {
        (0..256).map(move |i| {
            let byte = self.bytes[i / 8];
            ((byte >> (i % 8)) & 1) == 1
        })
    }

    /// Returns a copy of the scalar bytes.
    pub const fn to_bytes(&self) -> [u8; 32] {
        self.bytes
    }
}
/// Clamps 32 little-endian scalar bytes as X25519 requires (RFC 7748 `decodeScalar25519`).
///
/// Clears the three lowest bits, clears bit 255 and sets bit 254. Applying it twice
/// gives the same result as applying it once.
#[must_use]
pub const fn clamp_integer(mut bytes: [u8; 32]) -> [u8; 32] {
    // Make the scalar a multiple of 8.
    bytes[0] &= !0b0000_0111;
    // Top byte: keep bits 0..=5, force bit 6 (bit 254) on and bit 7 (bit 255) off.
    bytes[31] = (bytes[31] & 0b0011_1111) | 0b0100_0000;
    bytes
}
/// A Curve25519 point in Montgomery form, represented only by the 32-byte little-endian
/// encoding of its u-coordinate.
///
/// The bytes are stored as given. They are decoded with `FieldElement51::from_bytes`
/// (which ignores bit 255) when used.
#[derive(Copy, Clone, Debug, Default)]
pub struct MontgomeryPoint(pub [u8; 32]);
impl ConstantTimeEq for MontgomeryPoint {
    /// Compares the decoded u-coordinates as field elements.
    ///
    /// Encodings that differ only in the ignored bit 255, or that are non-canonical
    /// representatives of the same residue, compare equal.
    fn ct_eq(&self, other: &MontgomeryPoint) -> Choice {
        let decode = |bytes: &[u8; 32]| FieldElement51::from_bytes(bytes);
        decode(&self.0).ct_eq(&decode(&other.0))
    }
}
impl ConditionallySelectable for MontgomeryPoint {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
let mut new_bytes = [0u8; 32];
for i in 0..32 {
new_bytes[i] = u8::conditional_select(&a.0[i], &b.0[i], choice);
}
Self(new_bytes)
}
}
impl PartialEq for MontgomeryPoint {
    /// Equality through the constant-time comparison of decoded u-coordinates.
    fn eq(&self, other: &MontgomeryPoint) -> bool {
        bool::from(self.ct_eq(other))
    }
}

impl Eq for MontgomeryPoint {}
impl MontgomeryPoint {
    /// Returns the point with the all-zero encoding (u = 0).
    pub fn identity() -> MontgomeryPoint {
        MontgomeryPoint([0u8; 32])
    }
    /// Multiplies this point's u-coordinate by the scalar whose bits are yielded by
    /// `bits`, most significant bit first, using a constant-time Montgomery ladder.
    ///
    /// Ladder state: `x0` starts at the projective identity (U : W) = (1 : 0) and `x1` at
    /// (u : 1). Rather than swapping before and after each step, the swap is driven by
    /// `prev_bit ^ cur_bit`. A final swap on the last bit restores the ordering. The
    /// result is converted back to an affine u-coordinate encoding.
    pub fn mul_bits_be(&self, bits: impl Iterator<Item = bool>) -> MontgomeryPoint {
        let affine_u = FieldElement51::from_bytes(&self.0);
        let mut x0 = ProjectivePoint::identity();
        let mut x1 = ProjectivePoint {
            U: affine_u,
            W: FieldElement51::ONE,
        };
        let mut prev_bit = false;
        for cur_bit in bits {
            // Swap only when the bit changes: this merges the previous step's "swap back"
            // with this step's "swap in".
            let choice: u8 = (prev_bit ^ cur_bit) as u8;
            debug_assert!(choice == 0 || choice == 1);
            ProjectivePoint::conditional_swap(&mut x0, &mut x1, choice.into());
            // x0 <- 2*x0 and x1 <- x0 + x1; their difference stays the input point.
            differential_add_and_double(&mut x0, &mut x1, &affine_u);
            prev_bit = cur_bit;
        }
        // Undo the swap still pending from the last processed bit.
        ProjectivePoint::conditional_swap(&mut x0, &mut x1, Choice::from(prev_bit as u8));
        x0.as_affine()
    }
    /// Borrows the 32-byte u-coordinate encoding.
    pub const fn as_bytes(&self) -> &[u8; 32] {
        &self.0
    }
    /// Returns a copy of the 32-byte u-coordinate encoding.
    pub const fn to_bytes(&self) -> [u8; 32] {
        self.0
    }
}
/// A Montgomery u-coordinate in projective form (U : W), representing u = U / W.
///
/// `ProjectivePoint::identity` uses W = 0.
#[allow(non_snake_case)]
#[derive(Copy, Clone, Debug)]
struct ProjectivePoint {
    /// Projective numerator.
    pub U: FieldElement51,
    /// Projective denominator.
    pub W: FieldElement51,
}
impl ProjectivePoint {
    /// The ladder's starting accumulator, (U : W) = (1 : 0).
    pub fn identity() -> ProjectivePoint {
        ProjectivePoint {
            W: FieldElement51::ZERO,
            U: FieldElement51::ONE,
        }
    }

    /// Converts to an affine u-coordinate, u = U * W^-1, and encodes it canonically.
    ///
    /// When W = 0, `invert` returns 0, so the result encodes u = 0.
    pub fn as_affine(&self) -> MontgomeryPoint {
        let u = &self.U * &self.W.invert();
        MontgomeryPoint(u.as_bytes())
    }
}
impl ConditionallySelectable for ProjectivePoint {
    /// Selects both projective coordinates with the same `choice`.
    fn conditional_select(
        a: &ProjectivePoint,
        b: &ProjectivePoint,
        choice: Choice,
    ) -> ProjectivePoint {
        let u = FieldElement51::conditional_select(&a.U, &b.U, choice);
        let w = FieldElement51::conditional_select(&a.W, &b.W, choice);
        ProjectivePoint { U: u, W: w }
    }
}
/// One Montgomery-ladder step.
///
/// Replaces `P` with its double and `Q` with `P + Q`, given the affine u-coordinate of
/// `P - Q` (`affine_PmQ`, with implicit W = 1). Both results stay in projective form.
#[allow(non_snake_case)]
#[rustfmt::skip]
fn differential_add_and_double(
    P: &mut ProjectivePoint,
    Q: &mut ProjectivePoint,
    affine_PmQ: &FieldElement51,
) {
    let t0 = &P.U + &P.W;                  // U_P + W_P
    let t1 = &P.U - &P.W;                  // U_P - W_P
    let t2 = &Q.U + &Q.W;                  // U_Q + W_Q
    let t3 = &Q.U - &Q.W;                  // U_Q - W_Q
    let t4 = t0.square(); let t5 = t1.square(); // (U_P + W_P)^2, (U_P - W_P)^2
    let t6 = &t4 - &t5;                    // 4 U_P W_P
    let t7 = &t0 * &t3; let t8 = &t1 * &t2;
    let t9 = &t7 + &t8; let t10 = &t7 - &t8; // 2(U_P U_Q - W_P W_Q), 2(W_P U_Q - U_P W_Q)
    let t11 = t9.square(); let t12 = t10.square();
    let t13 = &APLUS2_OVER_FOUR * &t6;     // (A + 2) U_P W_P
    let t14 = &t4 * &t5; let t15 = &t13 + &t5; // (U_P^2 - W_P^2)^2, (U_P - W_P)^2 + (A + 2) U_P W_P
    let t16 = &t6 * &t15;                  // 4 U_P W_P (U_P^2 + A U_P W_P + W_P^2)
    let t17 = affine_PmQ * &t12; let t18 = t11; // u_(P-Q) * 4(W_P U_Q - U_P W_Q)^2, 4(U_P U_Q - W_P W_Q)^2
    // P <- 2P = (t14 : t16);  Q <- P + Q = (t18 : t17).
    P.U = t14; P.W = t16; Q.U = t18; Q.W = t17; }
impl<'a, 'b> Mul<&'b Scalar> for &'a MontgomeryPoint {
    type Output = MontgomeryPoint;
    /// Multiplies by `scalar`, feeding its bits to the ladder most significant first.
    ///
    /// Bit 255 is skipped, which is correct only when it is zero (`clamp_integer`
    /// guarantees this).
    /// NOTE(review): a `Scalar` built from unclamped bytes with bit 255 set would be
    /// silently multiplied by a different value. Confirm that every caller clamps.
    fn mul(self, scalar: &'b Scalar) -> MontgomeryPoint {
        let msb_first = scalar.bits_le().rev();
        self.mul_bits_be(msb_first.skip(1))
    }
}
// Owned, mixed-reference and compound-assignment forms of point * scalar. All forward to
// `&MontgomeryPoint * &Scalar`.
impl Mul<Scalar> for MontgomeryPoint {
    type Output = MontgomeryPoint;
    fn mul(self, scalar: Scalar) -> MontgomeryPoint {
        <&MontgomeryPoint as Mul<&Scalar>>::mul(&self, &scalar)
    }
}

impl Mul<&Scalar> for MontgomeryPoint {
    type Output = MontgomeryPoint;
    fn mul(self, scalar: &Scalar) -> MontgomeryPoint {
        <&MontgomeryPoint as Mul<&Scalar>>::mul(&self, scalar)
    }
}

impl Mul<MontgomeryPoint> for Scalar {
    type Output = MontgomeryPoint;
    /// Scalar-on-the-left form; computes the same product as `point * scalar`.
    fn mul(self, point: MontgomeryPoint) -> MontgomeryPoint {
        <&MontgomeryPoint as Mul<&Scalar>>::mul(&point, &self)
    }
}

impl Mul<&MontgomeryPoint> for Scalar {
    type Output = MontgomeryPoint;
    fn mul(self, point: &MontgomeryPoint) -> MontgomeryPoint {
        <&MontgomeryPoint as Mul<&Scalar>>::mul(point, &self)
    }
}

impl MulAssign<&Scalar> for MontgomeryPoint {
    fn mul_assign(&mut self, scalar: &Scalar) {
        let product = &*self * scalar;
        *self = product;
    }
}

impl MulAssign<Scalar> for MontgomeryPoint {
    fn mul_assign(&mut self, scalar: Scalar) {
        *self *= &scalar;
    }
}
use core::cmp;
use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Not};
use core::option::Option;
/// A constant-time boolean holding 0 (false) or 1 (true) in a `u8`.
///
/// Values built through `From<u8>` pass through `black_box`, which discourages the
/// optimizer from reasoning about them and reintroducing branches.
#[derive(Copy, Clone, Debug)]
pub struct Choice(u8);

impl Choice {
    /// Returns the underlying byte (0 or 1).
    #[inline]
    pub fn unwrap_u8(&self) -> u8 {
        let Choice(byte) = *self;
        byte
    }
}

impl From<Choice> for bool {
    /// Converts to `bool`; debug builds check that the byte is 0 or 1.
    #[inline]
    fn from(source: Choice) -> bool {
        debug_assert!((source.0 == 0u8) | (source.0 == 1u8));
        source.0 != 0
    }
}

impl BitAnd for Choice {
    type Output = Choice;
    /// Logical AND.
    #[inline]
    fn bitand(self, rhs: Choice) -> Choice {
        Choice::from(self.0 & rhs.0)
    }
}

impl BitAndAssign for Choice {
    #[inline]
    fn bitand_assign(&mut self, rhs: Choice) {
        let lhs = *self;
        *self = lhs & rhs;
    }
}

impl BitOr for Choice {
    type Output = Choice;
    /// Logical OR.
    #[inline]
    fn bitor(self, rhs: Choice) -> Choice {
        Choice::from(self.0 | rhs.0)
    }
}

impl BitOrAssign for Choice {
    #[inline]
    fn bitor_assign(&mut self, rhs: Choice) {
        let lhs = *self;
        *self = lhs | rhs;
    }
}

impl BitXor for Choice {
    type Output = Choice;
    /// Logical XOR.
    #[inline]
    fn bitxor(self, rhs: Choice) -> Choice {
        Choice::from(self.0 ^ rhs.0)
    }
}

impl BitXorAssign for Choice {
    #[inline]
    fn bitxor_assign(&mut self, rhs: Choice) {
        let lhs = *self;
        *self = lhs ^ rhs;
    }
}

impl Not for Choice {
    type Output = Choice;
    /// Logical NOT: inverts the low bit and clears all others.
    #[inline]
    fn not(self) -> Choice {
        Choice::from(!self.0 & 1)
    }
}

/// Optimization barrier: returns `input` after a volatile read, so the compiler cannot
/// assume anything about the value.
#[inline(never)]
fn black_box<T: Copy>(input: T) -> T {
    let slot: *const T = &input;
    // SAFETY: `slot` points at `input`, a live, initialized and properly aligned local of
    // type `T` owned by this frame. `T: Copy`, so reading it cannot duplicate ownership.
    unsafe { core::ptr::read_volatile(slot) }
}

impl From<u8> for Choice {
    /// Wraps `input`, which must be 0 or 1 (checked only in debug builds).
    #[inline]
    fn from(input: u8) -> Choice {
        debug_assert!((input == 0u8) | (input == 1u8));
        let hidden = black_box(input);
        Choice(hidden)
    }
}
/// Equality testing whose running time does not depend on the compared values.
pub trait ConstantTimeEq {
    /// Returns `Choice(1)` when `self` equals `other` and `Choice(0)` otherwise.
    fn ct_eq(&self, other: &Self) -> Choice;

    /// Constant-time inequality; the logical negation of `ct_eq`.
    #[inline]
    fn ct_ne(&self, other: &Self) -> Choice {
        let equal = self.ct_eq(other);
        !equal
    }
}
impl<T: ConstantTimeEq> ConstantTimeEq for [T] {
    /// Element-wise constant-time slice equality.
    ///
    /// Slices of different lengths return `Choice(0)` immediately, so only the lengths
    /// (not the contents) affect timing. Otherwise every pair is compared and the
    /// results are AND-ed together without an early exit.
    #[inline]
    fn ct_eq(&self, other: &[T]) -> Choice {
        if self.len() != other.len() {
            return Choice::from(0);
        }
        let all_equal = self
            .iter()
            .zip(other.iter())
            .fold(1u8, |acc, (lhs, rhs)| acc & lhs.ct_eq(rhs).unwrap_u8());
        Choice::from(all_equal)
    }
}
impl ConstantTimeEq for Choice {
    /// Two choices are equal exactly when their XOR is 0.
    #[inline]
    fn ct_eq(&self, rhs: &Choice) -> Choice {
        let differ = *self ^ *rhs;
        !differ
    }
}
// Generates `ConstantTimeEq` for an unsigned integer type `$t_u` and its signed
// counterpart `$t_i`. `$bit_width` is the width of `$t_u` in bits.
macro_rules! generate_integer_equal {
    ($t_u:ty, $t_i:ty, $bit_width:expr) => {
        impl ConstantTimeEq for $t_u {
            #[inline]
            fn ct_eq(&self, other: &$t_u) -> Choice {
                // x is zero exactly when self == other.
                let x: $t_u = self ^ other;
                // For nonzero x, at least one of x and -x has its top bit set; the shift
                // brings that bit down to bit 0. So y is 0 iff x == 0, and 1 otherwise.
                let y: $t_u = (x | x.wrapping_neg()) >> ($bit_width - 1);
                // Invert: 1 when equal, 0 otherwise.
                ((y ^ (1 as $t_u)) as u8).into()
            }
        }
        impl ConstantTimeEq for $t_i {
            #[inline]
            fn ct_eq(&self, other: &$t_i) -> Choice {
                // Signed values are compared by their bit patterns as the unsigned type.
                (*self as $t_u).ct_eq(&(*other as $t_u))
            }
        }
    };
}
generate_integer_equal!(u8, i8, 8);
generate_integer_equal!(u16, i16, 16);
generate_integer_equal!(u32, i32, 32);
generate_integer_equal!(u64, i64, 64);
generate_integer_equal!(usize, isize, ::core::mem::size_of::<usize>() * 8);
impl ConstantTimeEq for cmp::Ordering {
    /// Compares orderings through their `i8` discriminants
    /// (`Less = -1`, `Equal = 0`, `Greater = 1`).
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs = *self as i8;
        let rhs = *other as i8;
        lhs.ct_eq(&rhs)
    }
}
/// Types that can be chosen between in constant time, driven by a `Choice`.
pub trait ConditionallySelectable: Copy {
    /// Returns a copy of `a` when `choice` is 0 and a copy of `b` when `choice` is 1,
    /// without branching on `choice`.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self;

    /// Overwrites `self` with `other` when `choice` is 1; otherwise leaves it unchanged.
    #[inline]
    fn conditional_assign(&mut self, other: &Self, choice: Choice) {
        let selected = Self::conditional_select(self, other, choice);
        *self = selected;
    }

    /// Exchanges `a` and `b` when `choice` is 1; otherwise leaves both unchanged.
    #[inline]
    fn conditional_swap(a: &mut Self, b: &mut Self, choice: Choice) {
        let original_a: Self = *a;
        a.conditional_assign(b, choice);
        b.conditional_assign(&original_a, choice);
    }
}
// Maps an integer type token to the signed integer type of the same width. Signed types
// map to themselves. Used to turn a 0/1 `Choice` into an all-zeros/all-ones mask by
// negation.
macro_rules! to_signed_int {
    (u8) => { i8 };
    (u16) => { i16 };
    (u32) => { i32 };
    (u64) => { i64 };
    (u128) => { i128 };
    (i8) => { i8 };
    (i16) => { i16 };
    (i32) => { i32 };
    (i64) => { i64 };
    (i128) => { i128 };
}
// Generates `ConditionallySelectable` for each listed integer type. The mask is
// `-(choice as signed)`, i.e. all zero bits for choice 0 and all one bits for choice 1,
// so every method is branch-free XOR/AND arithmetic.
macro_rules! generate_integer_conditional_select {
    ($($t:tt)*) => ($(
        impl ConditionallySelectable for $t {
            #[inline]
            fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
                let mask = -(choice.unwrap_u8() as to_signed_int!($t)) as $t;
                // mask == 0 gives a; mask == !0 gives a ^ (a ^ b) == b.
                a ^ (mask & (a ^ b))
            }
            #[inline]
            fn conditional_assign(&mut self, other: &Self, choice: Choice) {
                let mask = -(choice.unwrap_u8() as to_signed_int!($t)) as $t;
                *self ^= mask & (*self ^ *other);
            }
            #[inline]
            fn conditional_swap(a: &mut Self, b: &mut Self, choice: Choice) {
                let mask = -(choice.unwrap_u8() as to_signed_int!($t)) as $t;
                // t is a ^ b when swapping and 0 otherwise; XOR-ing it into both swaps them.
                let t = mask & (*a ^ *b);
                *a ^= t;
                *b ^= t;
            }
        }
    )*)
}
generate_integer_conditional_select!( u8 i8);
generate_integer_conditional_select!( u16 i16);
generate_integer_conditional_select!( u32 i32);
generate_integer_conditional_select!( u64 i64);
impl ConditionallySelectable for cmp::Ordering {
    /// Selects between two orderings via their `i8` discriminants, without branching.
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let a = *a as i8;
        let b = *b as i8;
        let ret = i8::conditional_select(&a, &b, choice);
        // SAFETY: `cmp::Ordering` is `#[repr(i8)]`. Given that `choice` is 0 or 1 (the
        // `Choice` invariant), `ret` is bit-for-bit equal to `a` or `b`, each cast from a
        // valid `Ordering`, so reinterpreting it yields a valid `Ordering`. A `match` on
        // `ret` would avoid `unsafe` but would branch on the selected value.
        unsafe { *((&ret as *const _) as *const cmp::Ordering) }
    }
}
impl ConditionallySelectable for Choice {
    /// Selects between the inner bytes of `a` and `b`.
    #[inline]
    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
        let (lhs, rhs) = (a.0, b.0);
        Choice(u8::conditional_select(&lhs, &rhs, choice))
    }
}
/// Types that can be negated in constant time, depending on a `Choice`.
pub trait ConditionallyNegatable {
    /// Replaces `self` with `-self` when `choice` is 1; otherwise leaves it unchanged.
    fn conditional_negate(&mut self, choice: Choice);
}

/// Blanket implementation for any selectable type whose references support negation.
impl<T> ConditionallyNegatable for T
where
    T: ConditionallySelectable,
    for<'a> &'a T: Neg<Output = T>,
{
    #[inline]
    fn conditional_negate(&mut self, choice: Choice) {
        // The negation is always computed so that timing does not depend on `choice`.
        let negated: T = -&*self;
        self.conditional_assign(&negated, choice);
    }
}
/// An optional value whose presence is tracked by a `Choice` instead of an enum
/// discriminant, so its combinators can avoid branching on presence.
///
/// `value` is always stored, even when `is_some` is 0.
#[derive(Clone, Copy, Debug)]
pub struct CtOption<T> {
    /// The payload; meaningful only when `is_some` is 1.
    value: T,
    /// 1 when `value` is present, 0 otherwise.
    is_some: Choice,
}
impl<T> From<CtOption<T>> for Option<T> {
    /// Converts to a standard `Option`. This conversion necessarily branches on presence.
    fn from(source: CtOption<T>) -> Option<T> {
        match source.is_some().unwrap_u8() {
            1u8 => Some(source.value),
            _ => None,
        }
    }
}
impl<T> CtOption<T> {
    /// Wraps `value`, marking it present iff `is_some` is 1. The value is stored either way.
    #[inline]
    pub fn new(value: T, is_some: Choice) -> CtOption<T> {
        CtOption { value, is_some }
    }

    /// Returns the value.
    ///
    /// # Panics
    /// Panics with `msg` when the value is absent.
    pub fn expect(self, msg: &str) -> T {
        assert_eq!(self.is_some.unwrap_u8(), 1, "{}", msg);
        self.value
    }

    /// Returns the value.
    ///
    /// # Panics
    /// Panics when the value is absent.
    #[inline]
    pub fn unwrap(self) -> T {
        assert_eq!(self.is_some.unwrap_u8(), 1);
        self.value
    }

    /// Returns the value if present, otherwise `def`; selected in constant time.
    #[inline]
    pub fn unwrap_or(self, def: T) -> T
    where
        T: ConditionallySelectable,
    {
        let CtOption { value, is_some } = self;
        T::conditional_select(&def, &value, is_some)
    }

    /// Returns the value if present, otherwise the result of `f`.
    ///
    /// `f` is always called, so timing does not reveal presence.
    #[inline]
    pub fn unwrap_or_else<F>(self, f: F) -> T
    where
        T: ConditionallySelectable,
        F: FnOnce() -> T,
    {
        let fallback = f();
        T::conditional_select(&fallback, &self.value, self.is_some)
    }

    /// Returns `Choice(1)` when the value is present.
    #[inline]
    pub fn is_some(&self) -> Choice {
        self.is_some
    }

    /// Returns `Choice(1)` when the value is absent.
    #[inline]
    pub fn is_none(&self) -> Choice {
        !self.is_some
    }

    /// Applies `f` to the value, keeping the same presence flag.
    ///
    /// `f` always runs. It receives the real value when present and `T::default()`
    /// otherwise.
    #[inline]
    pub fn map<U, F>(self, f: F) -> CtOption<U>
    where
        T: Default + ConditionallySelectable,
        F: FnOnce(T) -> U,
    {
        let input = T::conditional_select(&T::default(), &self.value, self.is_some);
        CtOption::new(f(input), self.is_some)
    }

    /// Chains a fallible computation.
    ///
    /// The result is present only when both `self` and the output of `f` are present.
    /// `f` always runs, receiving `T::default()` when `self` is absent.
    #[inline]
    pub fn and_then<U, F>(self, f: F) -> CtOption<U>
    where
        T: Default + ConditionallySelectable,
        F: FnOnce(T) -> CtOption<U>,
    {
        let input = T::conditional_select(&T::default(), &self.value, self.is_some);
        let mut result = f(input);
        result.is_some &= self.is_some;
        result
    }

    /// Returns `self` if present, otherwise the result of `f`. `f` always runs.
    #[inline]
    pub fn or_else<F>(self, f: F) -> CtOption<T>
    where
        T: ConditionallySelectable,
        F: FnOnce() -> CtOption<T>,
    {
        let take_fallback = self.is_none();
        let fallback = f();
        Self::conditional_select(&self, &fallback, take_fallback)
    }

    /// Converts to a standard `Option` (branches on presence).
    pub fn into_option(self) -> Option<T> {
        Option::from(self)
    }
}
impl<T: ConditionallySelectable> ConditionallySelectable for CtOption<T> {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
CtOption::new(
T::conditional_select(&a.value, &b.value, choice),
Choice::conditional_select(&a.is_some, &b.is_some, choice),
)
}
}
impl<T: ConstantTimeEq> ConstantTimeEq for CtOption<T> {
    /// Two options are equal when both are absent, or both are present with equal payloads.
    ///
    /// The payloads are always compared, whatever their presence.
    #[inline]
    fn ct_eq(&self, rhs: &CtOption<T>) -> Choice {
        let (lhs_some, rhs_some) = (self.is_some(), rhs.is_some());
        let both_present_and_equal = lhs_some & rhs_some & self.value.ct_eq(&rhs.value);
        let both_absent = !lhs_some & !rhs_some;
        both_present_and_equal | both_absent
    }
}
/// Constant-time "greater than" comparison.
pub trait ConstantTimeGreater {
    /// Returns `Choice(1)` when `self > other` and `Choice(0)` otherwise.
    fn ct_gt(&self, other: &Self) -> Choice;
}
// Generates `ConstantTimeGreater` for an unsigned integer type using bit smearing. The
// result is 1 iff the most significant bit where `self` and `other` differ is set in
// `self`. Loop counts depend only on `$bit_width`, never on the values.
macro_rules! generate_unsigned_integer_greater {
    ($t_u: ty, $bit_width: expr) => {
        impl ConstantTimeGreater for $t_u {
            #[inline]
            fn ct_gt(&self, other: &$t_u) -> Choice {
                // gtb: bits set in self but not in other; ltb: bits set in other but not in self.
                let gtb = self & !other; let mut ltb = !self & other; let mut pow = 1;
                // Smear ltb's highest set bit into every lower position.
                while pow < $bit_width {
                    ltb |= ltb >> pow; pow += pow;
                }
                // Only gtb bits above every ltb bit survive. The result is nonzero iff the
                // highest differing bit favors self.
                let mut bit = gtb & !ltb; let mut pow = 1;
                // Smear any surviving bit down to bit 0.
                while pow < $bit_width {
                    bit |= bit >> pow; pow += pow;
                }
                Choice::from((bit & 1) as u8)
            }
        }
    };
}
generate_unsigned_integer_greater!(u8, 8);
generate_unsigned_integer_greater!(u16, 16);
generate_unsigned_integer_greater!(u32, 32);
generate_unsigned_integer_greater!(u64, 64);
impl ConstantTimeGreater for cmp::Ordering {
    /// Compares orderings by shifting their `i8` discriminants (-1, 0, 1) into `0..=2`
    /// and comparing them as unsigned bytes.
    #[inline]
    fn ct_gt(&self, other: &Self) -> Choice {
        let shift = |o: &cmp::Ordering| ((*o as i8) + 1) as u8;
        shift(self).ct_gt(&shift(other))
    }
}
/// Constant-time "less than", derived from `ConstantTimeGreater` and `ConstantTimeEq`.
pub trait ConstantTimeLess: ConstantTimeEq + ConstantTimeGreater {
    /// Returns `Choice(1)` when `self < other`, i.e. `self` is neither greater than nor
    /// equal to `other`.
    #[inline]
    fn ct_lt(&self, other: &Self) -> Choice {
        let greater = self.ct_gt(other);
        let equal = self.ct_eq(other);
        !greater & !equal
    }
}

impl ConstantTimeLess for u8 {}
impl ConstantTimeLess for u16 {}
impl ConstantTimeLess for u32 {}
impl ConstantTimeLess for u64 {}
impl ConstantTimeLess for cmp::Ordering {
    /// Compares orderings by their `i8` discriminants shifted into `0..=2` as unsigned bytes.
    #[inline]
    fn ct_lt(&self, other: &Self) -> Choice {
        let lhs = ((*self as i8) + 1) as u8;
        let rhs = ((*other as i8) + 1) as u8;
        lhs.ct_lt(&rhs)
    }
}
/// A `Copy` value that passes through an optimization barrier every time it is read.
#[derive(Clone, Copy, Debug)]
pub struct BlackBox<T: Copy>(T);

impl<T: Copy> BlackBox<T> {
    /// Stores `value`.
    pub const fn new(value: T) -> Self {
        BlackBox(value)
    }

    /// Returns the stored value, routed through `black_box` so the optimizer cannot make
    /// assumptions about it.
    pub fn get(self) -> T {
        let BlackBox(inner) = self;
        black_box(inner)
    }
}