use core::{
fmt,
hash::{Hash, Hasher},
};
use crate::{
SecretBytes,
auth::curve25519_edwards,
backend::curve25519::{FieldElement, clamp_secret_scalar},
traits::ct,
};
/// Size in bytes of an encoded Curve25519 u-coordinate; every key type in this
/// module uses it as its `LENGTH`.
const POINT_LENGTH: usize = 32;
/// Bits per limb in the radix-2^51 field-element representation.
const RADIX_BITS: u32 = 51;
/// Mask keeping the low `RADIX_BITS` bits of a limb (2^51 - 1).
const MASK51: u64 = u64::MAX >> (u64::BITS - RADIX_BITS);
/// Little-endian encoding of the X25519 base point, u = 9 (RFC 7748 §4.1).
const BASEPOINT_BYTES: [u8; POINT_LENGTH] = {
    let mut encoded = [0u8; POINT_LENGTH];
    encoded[0] = 9;
    encoded
};
/// Montgomery-ladder constant a24 = (A - 2) / 4 = 121665 for Curve25519,
/// whose Montgomery coefficient is A = 486662 (RFC 7748 §5).
const A24: FieldElement = FieldElement::from_small((486_662 - 2) / 4);
// Error returned by X25519 Diffie-Hellman when the computed shared secret is
// all-zero. RFC 7748 §6.1 recommends rejecting this output; it arises when
// the peer's public key is a small-order point.
define_unit_error! {
    pub struct X25519Error;
    "x25519 shared secret is all-zero"
}
/// An X25519 secret key, stored as the 32 raw bytes it was created from.
///
/// Clamping is not applied to the stored bytes; it happens on a copy when the
/// key is used. Equality is evaluated in constant time.
#[derive(Clone)]
pub struct X25519SecretKey([u8; Self::LENGTH]);

impl PartialEq for X25519SecretKey {
    /// Compares the raw key bytes without data-dependent early exit.
    fn eq(&self, other: &Self) -> bool {
        let Self(lhs) = self;
        let Self(rhs) = other;
        ct::constant_time_eq(lhs, rhs)
    }
}

impl Eq for X25519SecretKey {}
impl X25519SecretKey {
pub const LENGTH: usize = POINT_LENGTH;
#[inline]
#[must_use]
pub const fn from_bytes(bytes: [u8; Self::LENGTH]) -> Self {
Self(bytes)
}
#[inline]
#[must_use]
pub fn expose_secret(&self) -> SecretBytes<{ Self::LENGTH }> {
SecretBytes::new(self.0)
}
#[inline]
#[must_use]
pub const fn as_bytes(&self) -> &[u8; Self::LENGTH] {
&self.0
}
#[inline]
#[must_use]
pub fn generate(fill: impl FnOnce(&mut [u8; Self::LENGTH])) -> Self {
let mut bytes = [0u8; Self::LENGTH];
fill(&mut bytes);
Self(bytes)
}
impl_getrandom!();
#[must_use]
pub fn public_key(&self) -> X25519PublicKey {
public_key_from_scalar(&self.clamped_scalar_bytes())
}
pub fn diffie_hellman(&self, public: &X25519PublicKey) -> Result<X25519SharedSecret, X25519Error> {
X25519SharedSecret::diffie_hellman(self, public)
}
#[inline]
#[must_use]
fn clamped_scalar_bytes(&self) -> [u8; Self::LENGTH] {
let mut scalar = self.0;
clamp_secret_scalar(&mut scalar);
scalar
}
}
impl fmt::Debug for X25519SecretKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("X25519SecretKey(****)")
}
}
impl_hex_fmt_secret!(X25519SecretKey);
impl_serde_secret_bytes!(X25519SecretKey);
impl Drop for X25519SecretKey {
fn drop(&mut self) {
ct::zeroize(&mut self.0);
}
}
impl_ct_eq!(X25519SecretKey);
/// An X25519 public key (a Montgomery u-coordinate).
///
/// The encoding is kept exactly as supplied, alongside the decoded field
/// element that the ladder consumes.
#[derive(Clone, Copy)]
pub struct X25519PublicKey {
    bytes: [u8; Self::LENGTH],
    u: FieldElement,
}

impl X25519PublicKey {
    /// Length in bytes of an encoded public key.
    pub const LENGTH: usize = POINT_LENGTH;

    /// Parses a public key from its 32-byte little-endian encoding.
    ///
    /// Never fails: the top bit of the last byte is ignored while decoding and
    /// the decoded value is normalized, but the original bytes are retained.
    #[must_use]
    pub fn from_bytes(bytes: [u8; Self::LENGTH]) -> Self {
        let u = decode_u_coordinate(&bytes);
        Self { bytes, u }
    }

    /// The X25519 base point, u = 9.
    #[inline]
    #[must_use]
    pub fn basepoint() -> Self {
        X25519PublicKey::from_bytes(BASEPOINT_BYTES)
    }

    /// Builds a key from a field element, encoding it with `to_bytes`.
    #[inline]
    #[must_use]
    fn from_u(u: FieldElement) -> Self {
        let bytes = u.to_bytes();
        Self { bytes, u }
    }

    /// Returns the stored 32-byte encoding by value.
    #[inline]
    #[must_use]
    pub const fn to_bytes(self) -> [u8; Self::LENGTH] {
        self.bytes
    }

    /// Borrows the stored 32-byte encoding.
    #[inline]
    #[must_use]
    pub const fn as_bytes(&self) -> &[u8; Self::LENGTH] {
        &self.bytes
    }
}
impl PartialEq for X25519PublicKey {
fn eq(&self, other: &Self) -> bool {
self.bytes == other.bytes
}
}
impl Eq for X25519PublicKey {}
impl Hash for X25519PublicKey {
fn hash<H: Hasher>(&self, state: &mut H) {
self.bytes.hash(state);
}
}
impl fmt::Debug for X25519PublicKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "X25519PublicKey(")?;
crate::hex::fmt_hex_lower(&self.bytes, f)?;
write!(f, ")")
}
}
impl_hex_fmt!(X25519PublicKey);
impl_serde_bytes!(X25519PublicKey);
impl_ct_eq!(X25519PublicKey, bytes);
impl From<&X25519SecretKey> for X25519PublicKey {
#[inline]
fn from(secret: &X25519SecretKey) -> Self {
secret.public_key()
}
}
impl From<X25519SecretKey> for X25519PublicKey {
#[inline]
fn from(secret: X25519SecretKey) -> Self {
secret.public_key()
}
}
/// The 32-byte output of an X25519 Diffie-Hellman exchange.
///
/// Equality is evaluated in constant time.
#[derive(Clone)]
pub struct X25519SharedSecret([u8; Self::LENGTH]);

impl PartialEq for X25519SharedSecret {
    /// Compares the secret bytes without data-dependent early exit.
    fn eq(&self, other: &Self) -> bool {
        let Self(lhs) = self;
        let Self(rhs) = other;
        ct::constant_time_eq(lhs, rhs)
    }
}

impl Eq for X25519SharedSecret {}
impl X25519SharedSecret {
pub const LENGTH: usize = POINT_LENGTH;
#[inline]
#[must_use]
pub const fn from_bytes(bytes: [u8; Self::LENGTH]) -> Self {
Self(bytes)
}
#[inline]
#[must_use]
pub fn expose_secret(&self) -> SecretBytes<{ Self::LENGTH }> {
SecretBytes::new(self.0)
}
#[inline]
#[must_use]
pub const fn as_bytes(&self) -> &[u8; Self::LENGTH] {
&self.0
}
pub fn diffie_hellman(secret: &X25519SecretKey, public: &X25519PublicKey) -> Result<Self, X25519Error> {
let shared = montgomery_ladder(&secret.clamped_scalar_bytes(), &public.u).to_bytes();
if is_all_zero(&shared) {
Err(X25519Error::new())
} else {
Ok(Self(shared))
}
}
}
impl fmt::Debug for X25519SharedSecret {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("X25519SharedSecret(****)")
}
}
impl_hex_fmt_secret!(X25519SharedSecret);
impl_serde_secret_bytes!(X25519SharedSecret);
impl Drop for X25519SharedSecret {
fn drop(&mut self) {
ct::zeroize(&mut self.0);
}
}
impl_ct_eq!(X25519SharedSecret);
/// Computes the X25519 function: the u-coordinate of `scalar · P`, where `P`
/// has u-coordinate `u`, using the Montgomery ladder of RFC 7748 §5.
///
/// Bits 254 down to 0 of the little-endian `scalar_bytes` are processed (bit
/// 255 is skipped, as in the RFC). Every bit runs the same field operations,
/// and the working points are exchanged with `conditional_swap` rather than a
/// branch, so control flow does not depend on the scalar. The visible callers
/// pass an already-clamped scalar.
#[must_use]
fn montgomery_ladder(scalar_bytes: &[u8; POINT_LENGTH], u: &FieldElement) -> FieldElement {
    // Scalar bits from most to least significant, with bit 255 dropped.
    let bits_msb_first = scalar_bytes
        .iter()
        .rev()
        .flat_map(|&byte| (0..8u32).rev().map(move |shift| (byte >> shift) & 1))
        .skip(1);

    let x1 = *u;
    // (x2 : z2) starts at the point at infinity, (x3 : z3) at the input point.
    let (mut x2, mut z2) = (FieldElement::ONE, FieldElement::ZERO);
    let (mut x3, mut z3) = (x1, FieldElement::ONE);
    // Swap state carried between iterations so each bit costs one cswap pair.
    let mut swapped = 0u8;

    for bit in bits_msb_first {
        swapped ^= bit;
        FieldElement::conditional_swap(&mut x2, &mut x3, swapped);
        FieldElement::conditional_swap(&mut z2, &mut z3, swapped);
        swapped = bit;

        // Combined doubling of (x2 : z2) and differential addition into
        // (x3 : z3), using the RFC 7748 variable names.
        let a = x2.add(&z2);
        let b = x2.sub(&z2);
        let aa = a.square();
        let bb = b.square();
        let e = aa.sub(&bb);
        let c = x3.add(&z3);
        let d = x3.sub(&z3);
        let da = d.mul(&a);
        let cb = c.mul(&b);

        x3 = da.add(&cb).square();
        z3 = x1.mul(&da.sub(&cb).square());
        x2 = aa.mul(&bb);
        z2 = e.mul(&aa.add(&A24.mul(&e)));
    }

    FieldElement::conditional_swap(&mut x2, &mut x3, swapped);
    FieldElement::conditional_swap(&mut z2, &mut z3, swapped);
    // Convert back to affine: x2 / z2.
    x2.mul(&z2.invert())
}
/// Derives an X25519 public key from a scalar (the callers in this file pass
/// the clamped secret scalar).
///
/// The base-point multiple is computed on the Edwards curve via
/// `basepoint_mul_dispatch` and then mapped to its Montgomery u-coordinate.
#[inline]
#[must_use]
fn public_key_from_scalar(scalar_bytes: &[u8; POINT_LENGTH]) -> X25519PublicKey {
    X25519PublicKey::from_u(
        curve25519_edwards::basepoint_mul_dispatch(scalar_bytes).to_montgomery_u(),
    )
}
/// Returns `true` when every byte of `bytes` is zero.
///
/// All bytes are OR-ed into one accumulator before a single comparison, so
/// there is no early exit on the first non-zero byte. The accumulator goes
/// through `black_box`, presumably to stop the optimizer from reintroducing a
/// short-circuiting scan.
#[must_use]
fn is_all_zero(bytes: &[u8; POINT_LENGTH]) -> bool {
    let folded = bytes.iter().fold(0u8, |bits, &byte| bits | byte);
    core::hint::black_box(folded) == 0
}
/// Decodes a 32-byte little-endian u-coordinate into a field element.
///
/// The most significant bit of the last byte is ignored (RFC 7748 §5). The
/// remaining 255 bits are split into five 51-bit limbs, least significant
/// first, and the result is passed through `FieldElement::normalize`.
#[must_use]
fn decode_u_coordinate(bytes: &[u8; POINT_LENGTH]) -> FieldElement {
    // View the encoding as four little-endian 64-bit words.
    let mut words = [0u64; 4];
    for (word, chunk) in words.iter_mut().zip(bytes.chunks_exact(8)) {
        *word = chunk
            .iter()
            .rev()
            .fold(0u64, |acc, &byte| (acc << 8) | u64::from(byte));
    }
    let [w0, w1, w2, w3] = words;

    // Limb i holds bits 51*i ..= 51*i + 50. Limbs that straddle a word
    // boundary combine the high bits of one word with the low bits of the
    // next. Masking the last limb to 51 bits discards bit 255, which is the
    // bit RFC 7748 says to ignore.
    let limbs = [
        w0 & MASK51,
        ((w0 >> 51) | (w1 << 13)) & MASK51,
        ((w1 >> 38) | (w2 << 26)) & MASK51,
        ((w2 >> 25) | (w3 << 39)) & MASK51,
        (w3 >> 12) & MASK51,
    ];
    FieldElement::from_limbs(limbs).normalize()
}