use crate::drbg::HmacDrbgSha256;
use noxtls_core::{Error, Result};
/// Mask that keeps the low 51 bits of a field-element limb.
const MASK51: u64 = u64::MAX >> 13;
/// The field prime p = 2^255 - 19 split into five 51-bit limbs, least
/// significant limb first: every limb is all ones except the lowest, which is
/// 2^51 - 19.
const P: [u64; 5] = [MASK51 - 18, MASK51, MASK51, MASK51, MASK51];
/// An X25519 private scalar: 32 raw bytes, stored unclamped.
///
/// `Debug` is implemented by hand so the secret bytes never end up in logs,
/// panic messages or `{:?}` output. `PartialEq` folds over every byte without
/// an early exit, so comparison time does not reveal where two keys differ.
#[derive(Clone, Eq)]
pub struct X25519PrivateKey {
    scalar: [u8; 32],
}
impl core::fmt::Debug for X25519PrivateKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Deliberately omit the scalar: it is secret key material.
        f.debug_struct("X25519PrivateKey").finish_non_exhaustive()
    }
}
impl PartialEq for X25519PrivateKey {
    fn eq(&self, other: &Self) -> bool {
        // Accumulate all byte differences before deciding (no short-circuit).
        let diff = self
            .scalar
            .iter()
            .zip(other.scalar.iter())
            .fold(0_u8, |acc, (a, b)| acc | (a ^ b));
        diff == 0
    }
}
/// An X25519 public value: the 32-byte little-endian encoding of a Montgomery
/// u-coordinate, held exactly as given (no masking or validation on creation).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct X25519PublicKey {
    /// Raw encoded u-coordinate; bit 255 is left untouched here.
    pub bytes: [u8; 32],
}
impl X25519PrivateKey {
    /// Wraps 32 raw scalar bytes. No clamping is applied here; the scalar is
    /// clamped whenever it is used.
    #[must_use]
    pub fn from_bytes(bytes: [u8; 32]) -> Self {
        Self { scalar: bytes }
    }
    /// Returns a copy of the raw (unclamped) scalar bytes.
    ///
    /// The returned array is an ordinary copy of secret material and is not
    /// wiped automatically; the caller is responsible for it.
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 32] {
        self.scalar
    }
    /// Overwrites the stored scalar with zeros.
    ///
    /// `Drop` calls this immediately before the key's memory is released, so
    /// a bare `fill(0)` is a dead store the optimizer is permitted to remove.
    /// `black_box` forces the zeroed bytes to be treated as observed, and the
    /// compiler fence keeps the stores from being reordered past this call.
    /// This is best-effort: copies handed out earlier (e.g. via `to_bytes`)
    /// are not affected.
    pub fn clear(&mut self) {
        self.scalar.fill(0);
        let _ = core::hint::black_box(&mut self.scalar);
        core::sync::atomic::compiler_fence(core::sync::atomic::Ordering::SeqCst);
    }
    /// Returns the scalar with X25519 clamping applied (low 3 bits cleared,
    /// bit 255 cleared, bit 254 set).
    #[must_use]
    pub fn clamped_scalar(&self) -> [u8; 32] {
        clamp_scalar(self.scalar)
    }
    /// Derives the matching public key: X25519(scalar, 9).
    #[must_use]
    pub fn public_key(&self) -> X25519PublicKey {
        X25519PublicKey {
            bytes: x25519_basepoint(&self.scalar),
        }
    }
    /// Raw X25519 with `peer`. Performs no validation of `peer` and does not
    /// reject an all-zero result; prefer [`Self::diffie_hellman_checked`].
    #[must_use]
    pub fn diffie_hellman(&self, peer: X25519PublicKey) -> [u8; 32] {
        x25519(&self.scalar, &peer.bytes)
    }
    /// X25519 with `peer`, rejecting known-bad peer keys and degenerate output.
    ///
    /// # Errors
    ///
    /// Returns `Error::CryptoFailure` if `peer.validate()` rejects the key, or
    /// if the computed shared secret is all zeros (the check RFC 7748 §6.1
    /// recommends for catching small-order peer points).
    pub fn diffie_hellman_checked(&self, peer: X25519PublicKey) -> Result<[u8; 32]> {
        peer.validate()?;
        let shared = self.diffie_hellman(peer);
        if is_all_zero(&shared) {
            return Err(Error::CryptoFailure("x25519 shared secret is all-zero"));
        }
        Ok(shared)
    }
}
impl Drop for X25519PrivateKey {
fn drop(&mut self) {
self.clear();
}
}
impl X25519PublicKey {
#[must_use]
pub fn from_bytes(bytes: [u8; 32]) -> Self {
Self { bytes }
}
#[must_use]
pub fn is_all_zero(self) -> bool {
is_all_zero(&self.bytes)
}
pub fn validate(self) -> Result<()> {
let masked = self.masked_u_coordinate();
if is_all_zero(&masked) {
return Err(Error::CryptoFailure(
"x25519 peer public key is low-order (masked zero)",
));
}
if is_montgomery_u_one(&masked) {
return Err(Error::CryptoFailure(
"x25519 peer public key is low-order (u=1)",
));
}
Ok(())
}
#[must_use]
fn masked_u_coordinate(self) -> [u8; 32] {
let mut masked = self.bytes;
masked[31] &= 0x7f;
masked
}
}
/// The X25519 function of RFC 7748: multiplies the Curve25519 point with
/// u-coordinate `u` by the clamped `scalar` and returns the resulting
/// u-coordinate as 32 little-endian bytes.
///
/// `scalar` is clamped internally and bit 255 of `u` is masked off before
/// decoding. `u` is not validated, and an all-zero result is returned as-is;
/// callers that need those checks should go through
/// `X25519PrivateKey::diffie_hellman_checked`.
///
/// Every loop iteration performs the same sequence of field operations; the
/// scalar bits only feed the `cswap` choice.
#[must_use]
pub fn x25519(scalar: &[u8; 32], u: &[u8; 32]) -> [u8; 32] {
    let k = clamp_scalar(*scalar);
    // RFC 7748: the most significant bit of the encoded u-coordinate is ignored.
    let mut u_masked = *u;
    u_masked[31] &= 0x7f;
    let x1 = FieldElement::from_bytes(&u_masked);
    // Projective ladder state: (x2 : z2) starts at (1 : 0), the point at
    // infinity; (x3 : z3) starts at the input point (u : 1).
    let mut x2 = FieldElement::one();
    let mut z2 = FieldElement::zero();
    let mut x3 = x1;
    let mut z3 = FieldElement::one();
    // Pending swap state; swaps are deferred and merged between iterations.
    let mut swap = 0_u8;
    // Scalar bits 254 down to 0 (clamping guarantees bit 255 is zero).
    for t in (0..255).rev() {
        let k_t = (k[t / 8] >> (t & 7)) & 1;
        // Swap only when this bit differs from the previous one.
        swap ^= k_t;
        FieldElement::cswap(&mut x2, &mut x3, swap);
        FieldElement::cswap(&mut z2, &mut z3, swap);
        swap = k_t;
        // Combined differential-addition and doubling step, following the
        // formulas in RFC 7748 §5.
        let a = x2.add(&z2);
        let aa = a.square();
        let b = x2.sub(&z2);
        let bb = b.square();
        let e = aa.sub(&bb);
        let c = x3.add(&z3);
        let d = x3.sub(&z3);
        let da = d.mul(&a);
        let cb = c.mul(&b);
        x3 = da.add(&cb).square();
        z3 = x1.mul(&da.sub(&cb).square());
        x2 = aa.mul(&bb);
        // 121665 = (486662 - 2) / 4, the constant a24 for Curve25519.
        z2 = e.mul(&aa.add(&e.mul_small(121665)));
    }
    // Apply the last deferred swap.
    FieldElement::cswap(&mut x2, &mut x3, swap);
    FieldElement::cswap(&mut z2, &mut z3, swap);
    // Projective -> affine: u = x2 / z2 (z2 = 0 maps to 0 via invert).
    x2.mul(&z2.invert()).to_bytes()
}
/// Computes the X25519 public value for `scalar`, i.e. X25519(scalar, 9),
/// where u = 9 is the Curve25519 base point.
#[must_use]
pub fn x25519_basepoint(scalar: &[u8; 32]) -> [u8; 32] {
    // Little-endian encoding of the base point's u-coordinate (9).
    const BASE_POINT_U: [u8; 32] = [
        9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, //
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    ];
    x25519(scalar, &BASE_POINT_U)
}
pub fn x25519_shared_secret(
private_key: X25519PrivateKey,
peer_public_key: X25519PublicKey,
) -> Result<[u8; 32]> {
private_key.diffie_hellman_checked(peer_public_key)
}
pub fn x25519_generate_private_key_auto(drbg: &mut HmacDrbgSha256) -> Result<X25519PrivateKey> {
let scalar = drbg.generate(32, b"x25519_private_scalar")?;
let bytes: [u8; 32] = scalar
.as_slice()
.try_into()
.map_err(|_| Error::InvalidLength("x25519 private scalar length mismatch"))?;
Ok(X25519PrivateKey::from_bytes(bytes))
}
/// An element of GF(p), p = 2^255 - 19, stored as five 51-bit limbs, least
/// significant first: value = Σ self.0[i] · 2^(51·i).
///
/// The arithmetic methods return partially reduced values (each finishes with
/// `carry_reduce`, so limbs are below 2^51, but the value may still be ≥ p).
/// Only `to_bytes` produces the canonical encoding.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
struct FieldElement([u64; 5]);
impl FieldElement {
    /// The additive identity, 0.
    #[must_use]
    fn zero() -> Self {
        Self([0; 5])
    }
    /// The multiplicative identity, 1.
    #[must_use]
    fn one() -> Self {
        Self([1, 0, 0, 0, 0])
    }
    /// Decodes 32 little-endian bytes. Bit 255 is dropped by the final limb
    /// mask; non-canonical values (≥ p) are accepted and only reduced by
    /// `to_bytes`.
    #[must_use]
    fn from_bytes(input: &[u8; 32]) -> Self {
        // Limb i starts at bit 51·i: each load begins at the byte holding that
        // bit and shifts off the bits below it within that byte.
        let l0 = load8(input, 0) & MASK51;
        let l1 = (load8(input, 6) >> 3) & MASK51;
        let l2 = (load8(input, 12) >> 6) & MASK51;
        let l3 = (load8(input, 19) >> 1) & MASK51;
        let l4 = (load8(input, 24) >> 12) & MASK51;
        Self([l0, l1, l2, l3, l4]).carry_reduce()
    }
    /// Canonical 32-byte little-endian encoding (value fully reduced mod p,
    /// bit 255 always zero).
    ///
    /// Implemented with shifts and ORs only, because the encoded value can be
    /// a shared secret and must not drive data-dependent branches.
    #[must_use]
    fn to_bytes(self) -> [u8; 32] {
        let h = self.normalize().0;
        // Repack five 51-bit limbs (bit offsets 0, 51, 102, 153, 204) into four
        // 64-bit words (bit offsets 0, 64, 128, 192). Every limb is below 2^51
        // after `normalize`, so no stray high bits leak into a neighbour.
        let words = [
            h[0] | (h[1] << 51),
            (h[1] >> 13) | (h[2] << 38),
            (h[2] >> 26) | (h[3] << 25),
            (h[3] >> 39) | (h[4] << 12),
        ];
        let mut out = [0_u8; 32];
        for (chunk, word) in out.chunks_exact_mut(8).zip(words) {
            chunk.copy_from_slice(&word.to_le_bytes());
        }
        out
    }
    /// Field addition.
    #[must_use]
    fn add(&self, rhs: &Self) -> Self {
        let mut out = [0_u64; 5];
        for (idx, item) in out.iter_mut().enumerate() {
            *item = self.0[idx].wrapping_add(rhs.0[idx]);
        }
        Self(out).carry_reduce()
    }
    /// Field subtraction. Adds 2p limb-wise first so no limb underflows when
    /// `rhs` has reduced (below 2^51) limbs.
    #[must_use]
    fn sub(&self, rhs: &Self) -> Self {
        let mut out = [0_u64; 5];
        for (idx, item) in out.iter_mut().enumerate() {
            *item = self.0[idx]
                .wrapping_add(P[idx] << 1)
                .wrapping_sub(rhs.0[idx]);
        }
        Self(out).carry_reduce()
    }
    /// Multiplies by a small constant (the ladder uses a24 = 121665). The
    /// carry out of the top limb is folded back as ×19 since 2^255 ≡ 19.
    #[must_use]
    fn mul_small(&self, scalar: u64) -> Self {
        let mut out = [0_u64; 5];
        let mut carry = 0_u128;
        for (idx, item) in out.iter_mut().enumerate() {
            let v = (self.0[idx] as u128) * (scalar as u128) + carry;
            *item = (v as u64) & MASK51;
            carry = v >> 51;
        }
        out[0] = out[0].wrapping_add((carry as u64) * 19);
        Self(out).carry_reduce()
    }
    /// Field multiplication: schoolbook 5×5 limb product in 128-bit
    /// accumulators, with partial products at or above 2^255 folded back
    /// multiplied by 19 (2^255 ≡ 19 mod p).
    #[must_use]
    fn mul(&self, rhs: &Self) -> Self {
        let a = self.0;
        let b = rhs.0;
        let c0 = (a[0] as u128) * (b[0] as u128)
            + 19 * ((a[1] as u128) * (b[4] as u128)
                + (a[2] as u128) * (b[3] as u128)
                + (a[3] as u128) * (b[2] as u128)
                + (a[4] as u128) * (b[1] as u128));
        let c1 = (a[0] as u128) * (b[1] as u128)
            + (a[1] as u128) * (b[0] as u128)
            + 19 * ((a[2] as u128) * (b[4] as u128)
                + (a[3] as u128) * (b[3] as u128)
                + (a[4] as u128) * (b[2] as u128));
        let c2 = (a[0] as u128) * (b[2] as u128)
            + (a[1] as u128) * (b[1] as u128)
            + (a[2] as u128) * (b[0] as u128)
            + 19 * ((a[3] as u128) * (b[4] as u128) + (a[4] as u128) * (b[3] as u128));
        let c3 = (a[0] as u128) * (b[3] as u128)
            + (a[1] as u128) * (b[2] as u128)
            + (a[2] as u128) * (b[1] as u128)
            + (a[3] as u128) * (b[0] as u128)
            + 19 * ((a[4] as u128) * (b[4] as u128));
        let c4 = (a[0] as u128) * (b[4] as u128)
            + (a[1] as u128) * (b[3] as u128)
            + (a[2] as u128) * (b[2] as u128)
            + (a[3] as u128) * (b[1] as u128)
            + (a[4] as u128) * (b[0] as u128);
        // Carry chain from the low to the high coefficient, wrapping the final
        // carry back into limb 0 (×19).
        let mut out = [0_u64; 5];
        out[0] = (c0 as u64) & MASK51;
        let mut carry = c0 >> 51;
        let c1 = c1 + carry;
        out[1] = (c1 as u64) & MASK51;
        carry = c1 >> 51;
        let c2 = c2 + carry;
        out[2] = (c2 as u64) & MASK51;
        carry = c2 >> 51;
        let c3 = c3 + carry;
        out[3] = (c3 as u64) & MASK51;
        carry = c3 >> 51;
        let c4 = c4 + carry;
        out[4] = (c4 as u64) & MASK51;
        carry = c4 >> 51;
        out[0] = out[0].wrapping_add((carry as u64) * 19);
        Self(out).carry_reduce()
    }
    /// Field squaring.
    #[must_use]
    fn square(&self) -> Self {
        self.mul(self)
    }
    /// Multiplicative inverse via Fermat: self^(p-2), with p - 2 = 2^255 - 21.
    /// The exponent is public, so branching on its bits leaks nothing secret.
    /// Zero maps to zero.
    #[must_use]
    fn invert(&self) -> Self {
        // p - 2 in little-endian bytes: 0xeb, 0xff × 30, 0x7f.
        let mut exp = [0xff_u8; 32];
        exp[0] = 0xeb;
        exp[31] = 0x7f;
        let mut base = *self;
        let mut result = Self::one();
        for i in 0..255 {
            if ((exp[i / 8] >> (i & 7)) & 1) == 1 {
                result = result.mul(&base);
            }
            base = base.square();
        }
        result
    }
    /// Swaps `a` and `b` when `choice == 1` and leaves them unchanged when
    /// `choice == 0`, using a mask rather than a branch. `choice` must be 0 or 1.
    fn cswap(a: &mut Self, b: &mut Self, choice: u8) {
        let mask = 0_u64.wrapping_sub(u64::from(choice));
        for i in 0..5 {
            let t = mask & (a.0[i] ^ b.0[i]);
            a.0[i] ^= t;
            b.0[i] ^= t;
        }
    }
    /// Propagates carries so each limb ends below 2^51. The carry out of limb 4
    /// re-enters limb 0 multiplied by 19 (2^255 ≡ 19). Two passes are enough
    /// for the limb sizes this type's operations produce.
    #[must_use]
    fn carry_reduce(self) -> Self {
        let mut h = self.0;
        for _ in 0..2 {
            let c0 = h[0] >> 51;
            h[0] &= MASK51;
            h[1] = h[1].wrapping_add(c0);
            let c1 = h[1] >> 51;
            h[1] &= MASK51;
            h[2] = h[2].wrapping_add(c1);
            let c2 = h[2] >> 51;
            h[2] &= MASK51;
            h[3] = h[3].wrapping_add(c2);
            let c3 = h[3] >> 51;
            h[3] &= MASK51;
            h[4] = h[4].wrapping_add(c3);
            let c4 = h[4] >> 51;
            h[4] &= MASK51;
            h[0] = h[0].wrapping_add(c4 * 19);
        }
        Self(h)
    }
    /// Fully reduces the value into [0, p) without secret-dependent branches.
    ///
    /// After `carry_reduce` the value is below 2^255 = p + 19, so at most one
    /// subtraction of p is needed.
    #[must_use]
    fn normalize(self) -> Self {
        let h = self.carry_reduce().0;
        // t = h - p with borrow propagation. Every limb of h and p is below
        // 2^51, so a negative limb difference appears as bit 63 of the wrapped
        // u64, and `& MASK51` yields that difference plus 2^51.
        let mut t = [0_u64; 5];
        let mut borrow = 0_u64;
        for ((t_limb, &h_limb), &p_limb) in t.iter_mut().zip(h.iter()).zip(P.iter()) {
            let diff = h_limb.wrapping_sub(p_limb).wrapping_sub(borrow);
            borrow = diff >> 63;
            *t_limb = diff & MASK51;
        }
        // No final borrow means h >= p, so t is the reduced value. Select with
        // a mask instead of an `if`: keep_t is all ones exactly when borrow == 0.
        let keep_t = borrow.wrapping_sub(1);
        let mut out = [0_u64; 5];
        for ((o, &t_limb), &h_limb) in out.iter_mut().zip(t.iter()).zip(h.iter()) {
            *o = (t_limb & keep_t) | (h_limb & !keep_t);
        }
        Self(out)
    }
}
/// Reads 8 bytes of `input` starting at `offset` as a little-endian `u64`.
///
/// Panics (slice index out of range) if `offset + 8 > 32`.
fn load8(input: &[u8; 32], offset: usize) -> u64 {
    let mut word = [0_u8; 8];
    word.copy_from_slice(&input[offset..offset + 8]);
    u64::from_le_bytes(word)
}
/// Applies X25519 scalar clamping: clears the low three bits, clears bit 255
/// and sets bit 254.
fn clamp_scalar(mut scalar: [u8; 32]) -> [u8; 32] {
    scalar[0] &= 0b1111_1000;
    scalar[31] = (scalar[31] & 0b0111_1111) | 0b0100_0000;
    scalar
}
/// Returns `true` iff every byte is zero. ORs all 32 bytes together before
/// testing, so there is no early exit on the first non-zero byte.
fn is_all_zero(bytes: &[u8; 32]) -> bool {
    bytes.iter().fold(0_u8, |seen, &byte| seen | byte) == 0
}
/// Returns `true` iff `bytes` is the little-endian encoding of exactly 1.
fn is_montgomery_u_one(bytes: &[u8; 32]) -> bool {
    match bytes.split_first() {
        Some((&1, rest)) => rest.iter().all(|&b| b == 0),
        _ => false,
    }
}