use super::groups::DhGroup;
use crate::bignum::{BoxedMontModulus, BoxedUint};
use crate::ct::ConstantTimeEq;
use crate::rng::RngCore;
use alloc::vec;
use alloc::vec::Vec;
/// Failures reported by the Diffie-Hellman key types in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// A public value was outside the accepted range for its group.
    InvalidPublicKey,
    /// The computed shared secret was a degenerate value (0 or 1).
    ContributoryFailure,
    /// The group parameters themselves were rejected.
    InvalidGroup,
    /// A private scalar was outside the accepted range for its group.
    InvalidScalar,
}

impl core::fmt::Display for Error {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Pick the message first so there is a single write call.
        let message = match *self {
            Self::InvalidPublicKey => "invalid Diffie-Hellman public value",
            Self::ContributoryFailure => "Diffie-Hellman shared secret failed contributory check",
            Self::InvalidGroup => "invalid Diffie-Hellman group parameters",
            Self::InvalidScalar => "Diffie-Hellman private scalar out of range",
        };
        f.write_str(message)
    }
}

// Only available with `std`, where the `std::error::Error` trait exists.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
/// Diffie-Hellman private key: a secret exponent bound to the group it is used in.
///
/// Deliberately has no `Debug` derive, so the exponent cannot be printed by accident.
#[derive(Clone)]
pub struct DhPrivateKey {
/// Group whose modulus `p` and generator `g` are used for every operation on this key.
group: DhGroup,
/// Secret exponent `x`; the public value is `g^x mod p`.
x: BoxedUint,
}
/// Diffie-Hellman public key: a public value together with the group it belongs to.
#[derive(Clone)]
pub struct DhPublicKey {
/// Group the public value was computed in, or was parsed against.
group: DhGroup,
/// Public value `y`, either derived as `g^x mod p` or parsed and range-checked from bytes.
y: BoxedUint,
}
/// Raw output of a Diffie-Hellman key agreement.
///
/// Holds the encoded shared value as produced by `DhPrivateKey::shared_secret`.
pub struct SharedSecret {
    /// Encoded shared value.
    bytes: Vec<u8>,
}

impl SharedSecret {
    /// Borrows the secret bytes without copying them.
    pub fn as_bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }

    /// Consumes the secret and hands the underlying buffer to the caller.
    pub fn into_bytes(self) -> Vec<u8> {
        let Self { bytes } = self;
        bytes
    }
}
impl DhPrivateKey {
    /// Generates a fresh private key for `group` from `rng`.
    ///
    /// Draws `ceil(priv_bits / 8)` random bytes, clears every bit above
    /// `group.priv_bits` and forces the top bit, so the exponent is exactly
    /// `priv_bits` bits long.
    ///
    /// NOTE(review): the exponent is not compared against `p`; this assumes the
    /// group guarantees `priv_bits` is smaller than the bit length of `p` —
    /// confirm against the `DhGroup` constructors.
    ///
    /// # Panics
    ///
    /// Panics if `group.priv_bits` is zero.
    pub fn generate<R: RngCore>(group: DhGroup, rng: &mut R) -> Self {
        let priv_bits = group.priv_bits;
        // A zero-bit exponent previously failed with an opaque arithmetic or
        // index panic; fail with an explicit message instead.
        assert!(
            priv_bits > 0,
            "Diffie-Hellman private exponent size must be non-zero"
        );
        let nbytes = priv_bits.div_ceil(8);
        let mut bytes = vec![0u8; nbytes];
        rng.fill_bytes(&mut bytes);
        // Number of significant bits in the most significant byte (1..=8).
        let top_bits = priv_bits - (nbytes - 1) * 8;
        // Drop the excess high bits, then pin the top bit so the bit length is exact.
        bytes[0] &= 0xFFu8 >> (8 - top_bits);
        bytes[0] |= 1u8 << (top_bits - 1);
        let x = BoxedUint::from_be_bytes(&bytes);
        DhPrivateKey { group, x }
    }

    /// Imports a private exponent from its big-endian encoding.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidScalar`] unless `1 <= x <= p - 2`. The value
    /// `p - 1` is rejected because `g^(p-1) mod p` is 1 for a prime `p`, which
    /// is a public value this module refuses to accept.
    pub fn from_bytes(group: DhGroup, bytes: &[u8]) -> Result<Self, Error> {
        let x = BoxedUint::from_be_bytes(bytes);
        let p_minus_one = group.p().sub(&BoxedUint::from_u64(1));
        if x.is_zero() || !x.lt(&p_minus_one) {
            return Err(Error::InvalidScalar);
        }
        Ok(DhPrivateKey { group, x })
    }

    /// Derives the matching public key, `y = g^x mod p`.
    pub fn public_key(&self) -> DhPublicKey {
        let m = BoxedMontModulus::new(self.group.p());
        let y = m.pow(self.group.g(), &self.x);
        DhPublicKey {
            group: self.group.clone(),
            y,
        }
    }

    /// Computes the shared secret `z = y_peer^x mod p`.
    ///
    /// The result is encoded big-endian into `group.byte_size()` bytes.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidPublicKey`] if `peer` was created for a different
    ///   `(p, g)` pair, or if its value is outside `2..=p - 2`.
    /// * [`Error::ContributoryFailure`] if the computed secret is 0 or 1.
    pub fn shared_secret(&self, peer: &DhPublicKey) -> Result<SharedSecret, Error> {
        let p = self.group.p();
        // A value generated under another modulus or generator would be reduced
        // under our `p` and yield a secret the peer can never derive. Both
        // parameters are public, so an ordinary comparison is sufficient.
        if peer.group.p() != p || peer.group.g() != self.group.g() {
            return Err(Error::InvalidPublicKey);
        }
        // Re-validate the range here: the caller may hold a key that was built
        // by a path other than `DhPublicKey::from_bytes`.
        let two = BoxedUint::from_u64(2);
        let p_minus_one = p.sub(&BoxedUint::from_u64(1));
        if peer.y.lt(&two) || !peer.y.lt(&p_minus_one) {
            return Err(Error::InvalidPublicKey);
        }
        let m = BoxedMontModulus::new(p);
        let z = m.pow(&peer.y, &self.x);
        // Evaluate both comparisons before branching so the check does not
        // short-circuit on the secret value.
        let zero_eq = z.ct_eq(&BoxedUint::from_u64(0));
        let one_eq = z.ct_eq(&BoxedUint::from_u64(1));
        if bool::from(zero_eq | one_eq) {
            return Err(Error::ContributoryFailure);
        }
        let bytes = z.to_be_bytes(self.group.byte_size());
        Ok(SharedSecret { bytes })
    }

    /// Returns the group this key belongs to.
    pub fn group(&self) -> &DhGroup {
        &self.group
    }

    /// Exports the secret exponent, big-endian, in `group.byte_size()` bytes.
    ///
    /// The returned buffer contains key material; handle it accordingly.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.x.to_be_bytes(self.group.byte_size())
    }
}
impl DhPublicKey {
pub fn to_bytes(&self) -> Vec<u8> {
self.y.to_be_bytes(self.group.byte_size())
}
pub fn from_bytes(group: DhGroup, bytes: &[u8]) -> Result<Self, Error> {
let y = BoxedUint::from_be_bytes(bytes);
let two = BoxedUint::from_u64(2);
let p_minus_one = group.p().sub(&BoxedUint::from_u64(1));
if y.lt(&two) || !y.lt(&p_minus_one) {
return Err(Error::InvalidPublicKey);
}
Ok(DhPublicKey { group, y })
}
pub fn group(&self) -> &DhGroup {
&self.group
}
pub fn y(&self) -> &BoxedUint {
&self.y
}
}
#[cfg(test)]
mod tests {
    use super::super::groups::{DhGroup, group14, group15, group16};
    use super::*;
    use crate::hash::Sha256;
    use crate::rng::HmacDrbg;

    /// Runs one complete exchange in `group`.
    ///
    /// Generates Alice's key and then Bob's from `rng`, derives both shared
    /// secrets, asserts that they agree and returns Alice's copy.
    fn agree<R: RngCore>(group: DhGroup, rng: &mut R) -> SharedSecret {
        let alice = DhPrivateKey::generate(group.clone(), &mut *rng);
        let bob = DhPrivateKey::generate(group, &mut *rng);
        let from_alice = alice.shared_secret(&bob.public_key()).unwrap();
        let from_bob = bob.shared_secret(&alice.public_key()).unwrap();
        assert_eq!(from_alice.as_bytes(), from_bob.as_bytes());
        from_alice
    }

    /// Panics unless `r` is `Err(Error::InvalidPublicKey)`.
    fn expect_invalid_pub(r: Result<DhPublicKey, Error>) {
        let Err(err) = r else {
            panic!("expected InvalidPublicKey, got Ok");
        };
        if err != Error::InvalidPublicKey {
            panic!("expected InvalidPublicKey, got {err:?}");
        }
    }

    /// Panics unless `r` is `Err(Error::InvalidScalar)`.
    fn expect_invalid_scalar(r: Result<DhPrivateKey, Error>) {
        let Err(err) = r else {
            panic!("expected InvalidScalar, got Ok");
        };
        if err != Error::InvalidScalar {
            panic!("expected InvalidScalar, got {err:?}");
        }
    }

    #[test]
    fn group14_keyx_roundtrip() {
        let mut rng = HmacDrbg::<Sha256>::new(b"dh-group14", b"nonce", &[]);
        let secret = agree(group14(), &mut rng);
        assert_eq!(secret.as_bytes().len(), 256);
        // The secret must not be the all-zero string.
        assert!(!secret.as_bytes().iter().all(|&b| b == 0));
    }

    #[test]
    fn group15_keyx_roundtrip() {
        let mut rng = HmacDrbg::<Sha256>::new(b"dh-group15", b"nonce", &[]);
        let secret = agree(group15(), &mut rng);
        assert_eq!(secret.as_bytes().len(), 384);
    }

    #[test]
    fn group16_keyx_roundtrip() {
        let mut rng = HmacDrbg::<Sha256>::new(b"dh-group16", b"nonce", &[]);
        let secret = agree(group16(), &mut rng);
        assert_eq!(secret.as_bytes().len(), 512);
    }

    // Ignored by default; covers the two largest groups with one shared DRBG.
    #[test]
    #[ignore]
    fn group18_keyx_roundtrip() {
        use super::super::groups::{group17, group18};
        let mut rng = HmacDrbg::<Sha256>::new(b"dh-group17-18", b"nonce", &[]);

        let secret = agree(group17(), &mut rng);
        assert_eq!(secret.as_bytes().len(), 768);

        let secret = agree(group18(), &mut rng);
        assert_eq!(secret.as_bytes().len(), 1024);
    }

    #[test]
    fn rejects_invalid_public_key_zero() {
        let zero = [0u8; 256];
        expect_invalid_pub(DhPublicKey::from_bytes(group14(), &zero));
    }

    #[test]
    fn rejects_invalid_public_key_one() {
        let mut one = [0u8; 256];
        one[255] = 1;
        expect_invalid_pub(DhPublicKey::from_bytes(group14(), &one));
    }

    #[test]
    fn rejects_invalid_public_key_p_minus_one() {
        let group = group14();
        let encoded = group.p().sub(&BoxedUint::from_u64(1)).to_be_bytes(256);
        expect_invalid_pub(DhPublicKey::from_bytes(group, &encoded));
    }

    #[test]
    fn rejects_invalid_public_key_ge_p() {
        // y == p
        let p_bytes = group14().p().to_be_bytes(256);
        expect_invalid_pub(DhPublicKey::from_bytes(group14(), &p_bytes));

        // y == p + 1, computed in a buffer one byte wider than p.
        let mut widened = vec![0u8; 257];
        widened[1..].copy_from_slice(&p_bytes);
        let p_plus_one = BoxedUint::from_be_bytes(&widened).add(&BoxedUint::from_u64(1));
        let p_plus_one_bytes = p_plus_one.to_be_bytes(257);
        expect_invalid_pub(DhPublicKey::from_bytes(group14(), &p_plus_one_bytes));
    }

    #[test]
    fn from_bytes_round_trip_public_key() {
        let mut rng = HmacDrbg::<Sha256>::new(b"dh-roundtrip", b"nonce", &[]);
        let original = DhPrivateKey::generate(group14(), &mut rng).public_key();
        let encoded = original.to_bytes();
        let reparsed = DhPublicKey::from_bytes(group14(), &encoded).unwrap();
        assert_eq!(original.to_bytes(), reparsed.to_bytes());
    }

    #[test]
    fn group_exchange_custom_group() {
        let reference = group14();
        let custom = DhGroup::from_custom(reference.p().clone(), reference.g().clone(), 256)
            .expect("from_custom accepts group14 (p, g)");
        assert_eq!(custom.name(), "custom");
        assert_eq!(custom.bit_size(), 2048);

        let mut rng = HmacDrbg::<Sha256>::new(b"dh-custom", b"nonce", &[]);
        agree(custom, &mut rng);
    }

    #[test]
    fn from_bytes_rejects_out_of_range_scalar() {
        // x == 0
        let zero = vec![0u8; 256];
        expect_invalid_scalar(DhPrivateKey::from_bytes(group14(), &zero));

        // x == p
        let p_bytes = group14().p().to_be_bytes(256);
        expect_invalid_scalar(DhPrivateKey::from_bytes(group14(), &p_bytes));

        // x == 1 is the smallest accepted scalar.
        let mut one = vec![0u8; 256];
        one[255] = 1;
        assert!(DhPrivateKey::from_bytes(group14(), &one).is_ok());
    }

    #[test]
    fn shared_secret_byte_length_matches_prime() {
        let mut rng = HmacDrbg::<Sha256>::new(b"dh-len", b"nonce", &[]);
        let alice = DhPrivateKey::generate(group14(), &mut rng);
        let bob = DhPrivateKey::generate(group14(), &mut rng);
        let secret = alice.shared_secret(&bob.public_key()).unwrap();
        let expected_len = group14().p().bit_len().div_ceil(8);
        assert_eq!(secret.as_bytes().len(), expected_len);
    }

    #[test]
    fn known_small_dh_via_custom_group() {
        // Textbook example: p = 23, g = 5, a = 6, b = 15.
        let group =
            DhGroup::from_custom(BoxedUint::from_u64(23), BoxedUint::from_u64(5), 4).unwrap();

        // 5^6 mod 23 == 8
        let alice = DhPrivateKey::from_bytes(group.clone(), &[6u8]).unwrap();
        let a_pub = alice.public_key();
        assert_eq!(a_pub.y(), &BoxedUint::from_u64(8));

        // 5^15 mod 23 == 19
        let bob = DhPrivateKey::from_bytes(group, &[15u8]).unwrap();
        let b_pub = bob.public_key();
        assert_eq!(b_pub.y(), &BoxedUint::from_u64(19));

        // 19^6 mod 23 == 8^15 mod 23 == 2
        let a_shared = alice.shared_secret(&b_pub).unwrap();
        let b_shared = bob.shared_secret(&a_pub).unwrap();
        assert_eq!(a_shared.as_bytes(), b_shared.as_bytes());
        assert_eq!(a_shared.as_bytes(), &[2u8]);
    }
}