use ed25519_dalek::{
SECRET_KEY_LENGTH, SigningKey as Ed25519SigningKey, VerifyingKey as Ed25519VerifyingKey,
};
use hkdf::Hkdf;
use ml_dsa::{B32, EncodedVerifyingKey, KeyGen as _, KeyPair, MlDsa87, SigningKey, VerifyingKey};
use rand::SeedableRng as _;
use rand_chacha::ChaCha20Rng;
use rand_chacha::rand_core::CryptoRngCore;
use subtle::ConstantTimeEq as _;
use zeroize::{Zeroize, ZeroizeOnDrop};
use std::io::{Cursor, ErrorKind, Read, Write};
use curve25519_dalek::montgomery::MontgomeryPoint;
use ml_kem::EncodedSizeUser;
use sha2::Sha512;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::base64::{base64_decode, base64_encode};
pub use crate::crypto::hybrid::{MLADecryptionPrivateKey, MLAEncryptionPublicKey};
use crate::crypto::hybrid::{MLKEM_DZ_SIZE, MLKEMEncapsulationKey, MLKEMSeed};
use crate::errors::Error;
use crate::layers::encrypt::get_crypto_rng;
use crate::{EMPTY_OPTS_SERIALIZATION, MLADeserialize};
use super::hybrid::generate_keypair_from_rng;
// Prefixes of the base64-encoded key lines inside key files: each key line is
// `<prefix><base64 payload>` and is checked with `strip_prefix` when parsing.
const MLA_PRIV_DEC_KEY_HEADER: &[u8] = b"MLA PRIVATE DECRYPTION KEY ";
const MLA_PRIV_SIG_KEY_HEADER: &[u8] = b"MLA PRIVATE SIGNING KEY ";
// Method identifiers written at the start of each decoded private key payload;
// parsing rejects any payload that does not start with the expected identifier.
const DEC_METHOD_ID_0_PRIV: &[u8] = b"mla-kem-private-x25519-mlkem1024";
const SIG_METHOD_ID_0_PRIV: &[u8] = b"mla-signature-private-ed25519-mldsa87";
// Public key line prefixes and method identifiers, used the same way as the
// private ones above.
const MLA_PUB_ENC_KEY_HEADER: &[u8] = b"MLA PUBLIC ENCRYPTION KEY ";
const MLA_PUB_SIGVERIF_KEY_HEADER: &[u8] = b"MLA PUBLIC SIGNATURE VERIFICATION KEY ";
const ENC_METHOD_ID_0_PUB: &[u8] = b"mla-kem-public-x25519-mlkem1024";
const SIGVERIF_METHOD_ID_0_PUB: &[u8] = b"mla-signature-verification-public-ed25519-mldsa87";
/// Size in bytes of the serialized Ed25519 private key inside a signing key payload.
const ED25519_PRIVKEY_SIZE: usize = 32;
// First and last lines of a private key file and of a public key file.
const PRIV_KEY_FILE_HEADER: &[u8] = b"DO NOT SEND THIS TO ANYONE - MLA PRIVATE KEY FILE V1";
const PRIV_KEY_FILE_FOOTER: &[u8] = b"END OF MLA PRIVATE KEY FILE";
const PUB_KEY_FILE_HEADER: &[u8] = b"MLA PUBLIC KEY FILE V1";
const PUB_KEY_FILE_FOOTER: &[u8] = b"END OF MLA PUBLIC KEY FILE";
/// Reads `src` until end-of-stream, without leaving copies of the data in freed memory.
///
/// `Read::read_to_end` may reallocate while growing its buffer and release the
/// old allocation without wiping it. Here the buffer is grown manually instead:
/// a zero-filled buffer of twice the size is allocated, the bytes read so far
/// are copied over, and the old buffer is zeroized before being dropped.
///
/// # Errors
/// Returns `Error::DeserializationError` on any read error other than
/// `ErrorKind::Interrupted` (which is retried), if the buffer size would
/// overflow `usize`, or if the reader reports more bytes than it was given.
/// The partially filled buffer is zeroized on every error path.
fn zeroizeable_read_to_end(mut src: impl Read) -> Result<Vec<u8>, Error> {
    let mut capacity: usize = 4096;
    let mut buf = vec![0u8; capacity];
    let mut filled: usize = 0;
    loop {
        if filled == capacity {
            capacity = match capacity.checked_mul(2) {
                Some(doubled) => doubled,
                None => {
                    buf.zeroize();
                    return Err(Error::DeserializationError);
                }
            };
            let mut grown = vec![0u8; capacity];
            grown[..filled].copy_from_slice(&buf[..filled]);
            // Wipe the old allocation before it is released.
            buf.zeroize();
            buf = grown;
        }
        match src.read(&mut buf[filled..capacity]) {
            Ok(0) => {
                buf.truncate(filled);
                return Ok(buf);
            }
            Ok(n) => {
                // A well-behaved reader never reports more than the slice it was
                // handed; treat a misbehaving one as an error instead of panicking
                // on the next slice operation.
                match filled.checked_add(n).filter(|&total| total <= capacity) {
                    Some(total) => filled = total,
                    None => {
                        buf.zeroize();
                        return Err(Error::DeserializationError);
                    }
                }
            }
            Err(e) if e.kind() == ErrorKind::Interrupted => {}
            Err(_) => {
                buf.zeroize();
                return Err(Error::DeserializationError);
            }
        }
    }
}
#[allow(clippy::needless_range_loop)]
fn split_in_five_parts_without_buffering(content: &[u8]) -> Result<[&[u8]; 5], Error> {
fn split_separator(content: &[u8]) -> Result<(&[u8], &[u8]), Error> {
let mut carriage_return_index = 0;
for i in 0..content.len() {
if content[i] == b'\r' || content[i] == b'_' {
carriage_return_index = i;
break;
}
}
if carriage_return_index == 0 {
return Err(Error::DeserializationError);
}
let (part, rest) = content.split_at(carriage_return_index);
if rest.len() < 2
|| (rest[0] != b'\r' && rest[0] != b'_')
|| (rest[0] == b'\r' && rest[1] != b'\n')
|| (rest[0] == b'_' && rest[1] != b'_')
{
return Err(Error::DeserializationError);
}
let rest = &rest[2..];
Ok((part, rest))
}
let (first_part, rest) = split_separator(content)?;
let (second_part, rest) = split_separator(rest)?;
let (third_part, rest) = split_separator(rest)?;
let (fourth_part, rest) = split_separator(rest)?;
let (fifth_part, rest) = split_separator(rest)?;
if !rest.is_empty() {
return Err(Error::DeserializationError);
}
Ok([first_part, second_part, third_part, fourth_part, fifth_part])
}
/// Options attached to every serialized key.
///
/// No options are defined yet: serialization always emits the empty options
/// encoding (`EMPTY_OPTS_SERIALIZATION`).
#[derive(Clone)]
struct KeyOpts;

impl KeyOpts {
    /// Writes the key-options line: base64 of `EMPTY_OPTS_SERIALIZATION`, then CRLF.
    ///
    /// # Errors
    /// Propagates base64 encoding and write failures.
    #[allow(clippy::unused_self)]
    fn serialize_key_opts<W: Write>(&self, mut dst: W) -> Result<(), Error> {
        let line = base64_encode(EMPTY_OPTS_SERIALIZATION)?;
        dst.write_all(&line)?;
        dst.write_all(b"\r\n")?;
        Ok(())
    }
}
impl<R: Read> MLADeserialize<R> for KeyOpts {
fn deserialize(src: &mut R) -> Result<Self, Error> {
const BUFFER_SIZE: usize = 4096;
let discriminant = u8::deserialize(src)?;
match discriminant {
0 => Ok(KeyOpts),
1 => {
let mut key_opts_len = [0; 8];
src.read_exact(&mut key_opts_len)?;
let key_opts_len = usize::try_from(u64::from_le_bytes(key_opts_len))
.map_err(|_| Error::DeserializationError)?;
let number_of_chunks = key_opts_len / BUFFER_SIZE;
let last_chunk_remainder = key_opts_len % BUFFER_SIZE;
let mut buffer = [0; BUFFER_SIZE];
for _ in 0..number_of_chunks {
let result = src.read_exact(&mut buffer);
buffer.zeroize();
if result.is_err() {
return Err(Error::DeserializationError);
}
}
{
let mut buffer = vec![0; last_chunk_remainder];
let result = src.read_exact(&mut buffer);
buffer.zeroize();
if result.is_err() {
return Err(Error::DeserializationError);
}
}
Ok(KeyOpts)
}
_ => Err(Error::DeserializationError),
}
}
}
impl MLADecryptionPrivateKey {
fn deserialize_decryption_private_key(line: &[u8]) -> Result<Self, Error> {
let b64data = line
.strip_prefix(MLA_PRIV_DEC_KEY_HEADER)
.ok_or(Error::DeserializationError)?;
let data = base64_decode(b64data).map_err(|_| Error::DeserializationError)?;
let mut cursor = Cursor::new(data);
let mut method_id = [0; DEC_METHOD_ID_0_PRIV.len()];
cursor
.read_exact(&mut method_id)
.map_err(|_| Error::DeserializationError)?;
if method_id.as_slice() != DEC_METHOD_ID_0_PRIV {
return Err(Error::DeserializationError);
}
let _opts = KeyOpts::deserialize(&mut cursor)?;
let mut serialized_ecc_key = [0; ECC_PRIVKEY_SIZE];
cursor
.read_exact(&mut serialized_ecc_key)
.map_err(|_| Error::DeserializationError)?;
let private_key_ecc = StaticSecret::from(serialized_ecc_key);
serialized_ecc_key.zeroize();
let mut serialized_mlkem_seed = [0; MLKEM_DZ_SIZE];
cursor
.read_exact(&mut serialized_mlkem_seed)
.map_err(|_| Error::DeserializationError)?;
let private_key_seed_ml = MLKEMSeed::from_d_z_64(serialized_mlkem_seed);
cursor.into_inner().zeroize();
Ok(Self {
private_key_ecc,
private_key_seed_ml,
})
}
fn serialize_decryption_private_key<W: Write>(&self, mut dst: W) -> Result<(), Error> {
dst.write_all(MLA_PRIV_DEC_KEY_HEADER)?;
let mut b64data = Vec::with_capacity(
DEC_METHOD_ID_0_PRIV
.len()
.checked_add(EMPTY_OPTS_SERIALIZATION.len())
.unwrap()
.checked_add(32)
.unwrap()
.checked_add(64)
.unwrap(),
);
b64data.extend_from_slice(DEC_METHOD_ID_0_PRIV);
b64data.extend_from_slice(EMPTY_OPTS_SERIALIZATION);
b64data.extend_from_slice(self.private_key_ecc.as_bytes());
b64data.extend_from_slice(self.private_key_seed_ml.to_d_z_64().as_ref());
let mut encoded = base64_encode(&b64data)?;
b64data.zeroize();
dst.write_all(&encoded)?;
encoded.zeroize();
dst.write_all(b"\r\n")?;
Ok(())
}
}
/// ML-DSA-87 private key material kept in seed form: the 32-byte `xi` value
/// from which the key pair is regenerated when a signing or verifying key is
/// requested (see `to_signing_key` / `to_signing_verification_key`).
#[derive(Clone)]
pub(crate) struct MLDSASeed {
    xi: B32,
}
/// Length in bytes of the serialized ML-DSA seed `xi`.
const MLDSA_XI_SIZE: usize = 32;
impl MLDSASeed {
fn as_slice(&self) -> &[u8] {
self.xi.as_slice()
}
fn generate_from_csprng(mut csprng: impl CryptoRngCore) -> Self {
let mut xi_array = [0u8; 32];
csprng.fill_bytes(&mut xi_array);
Self::from_xi_32(xi_array)
}
fn key_gen_internal(xi: &B32) -> KeyPair<MlDsa87> {
#[cfg(windows)]
{
use std::thread;
#[allow(clippy::clone_on_copy)]
let mut xi = xi.clone();
let builder = thread::Builder::new().stack_size(8 * 1024 * 1024);
let handle = builder
.spawn(move || {
let result = MlDsa87::key_gen_internal(&xi);
xi.zeroize();
result
})
.expect("Failed to spawn thread with increased stack");
handle.join().expect("Thread panicked")
}
#[cfg(not(windows))]
{
MlDsa87::key_gen_internal(xi)
}
}
fn from_xi_32(xi: [u8; 32]) -> Self {
let xi = B32::from(xi);
Self { xi }
}
pub(crate) fn to_signing_key(&self) -> SigningKey<MlDsa87> {
Self::key_gen_internal(&self.xi).signing_key().clone()
}
pub(crate) fn to_signing_verification_key(&self) -> VerifyingKey<MlDsa87> {
Self::key_gen_internal(&self.xi).verifying_key().clone()
}
}
impl PartialEq for MLDSASeed {
fn eq(&self, other: &Self) -> bool {
self.xi.ct_eq(&other.xi).into()
}
}
impl Zeroize for MLDSASeed {
fn zeroize(&mut self) {
self.xi.zeroize();
}
}
impl Drop for MLDSASeed {
fn drop(&mut self) {
self.zeroize();
}
}
impl ZeroizeOnDrop for MLDSASeed {}
/// Hybrid signing private key: an Ed25519 signing key plus an ML-DSA-87 seed.
#[derive(Clone)]
pub struct MLASigningPrivateKey {
    pub(crate) private_key_ed25519: Ed25519SigningKey,
    pub(crate) private_key_seed_mldsa: MLDSASeed,
    // Currently never read (hence `allow(dead_code)`).
    #[allow(dead_code)]
    opts: KeyOpts,
}
impl MLASigningPrivateKey {
fn deserialize_signing_private_key(line: &[u8]) -> Result<Self, Error> {
let b64data = line
.strip_prefix(MLA_PRIV_SIG_KEY_HEADER)
.ok_or(Error::DeserializationError)?;
let data = base64_decode(b64data).map_err(|_| Error::DeserializationError)?;
let mut cursor = Cursor::new(data);
let mut method_id = [0; SIG_METHOD_ID_0_PRIV.len()];
cursor
.read_exact(&mut method_id)
.map_err(|_| Error::DeserializationError)?;
if method_id.as_slice() != SIG_METHOD_ID_0_PRIV {
return Err(Error::DeserializationError);
}
let _opts = KeyOpts::deserialize(&mut cursor)?;
let mut serialized_ecc_key = [0; ED25519_PRIVKEY_SIZE];
cursor
.read_exact(&mut serialized_ecc_key)
.map_err(|_| Error::DeserializationError)?;
let private_key_ed25519 = Ed25519SigningKey::from_bytes(&serialized_ecc_key);
serialized_ecc_key.zeroize();
let mut serialized_mldsa_seed = [0; MLDSA_XI_SIZE];
cursor
.read_exact(&mut serialized_mldsa_seed)
.map_err(|_| Error::DeserializationError)?;
let private_key_mldsa87 = MLDSASeed::from_xi_32(serialized_mldsa_seed);
cursor.into_inner().zeroize();
Ok(Self {
private_key_ed25519,
private_key_seed_mldsa: private_key_mldsa87,
opts: KeyOpts,
})
}
fn serialize_signing_private_key<W: Write>(&self, mut dst: W) -> Result<(), Error> {
dst.write_all(MLA_PRIV_SIG_KEY_HEADER)?;
let mut b64data = Vec::with_capacity(
SIG_METHOD_ID_0_PRIV
.len()
.checked_add(EMPTY_OPTS_SERIALIZATION.len())
.unwrap()
.checked_add(SECRET_KEY_LENGTH)
.unwrap()
.checked_add(32)
.unwrap(),
);
b64data.extend_from_slice(SIG_METHOD_ID_0_PRIV);
b64data.extend_from_slice(EMPTY_OPTS_SERIALIZATION);
b64data.extend_from_slice(self.private_key_ed25519.as_bytes());
b64data.extend_from_slice(self.private_key_seed_mldsa.as_slice());
let mut encoded = base64_encode(&b64data)?;
b64data.zeroize();
dst.write_all(&encoded)?;
encoded.zeroize();
dst.write_all(b"\r\n")?;
Ok(())
}
}
/// Wipes the ML-DSA seed when a signing private key is dropped.
///
/// `MLDSASeed` also zeroizes itself in its own `Drop`; doing it here keeps the
/// intent explicit instead of leaving an empty, misleading `drop` body.
/// NOTE(review): the Ed25519 key is not wiped here; this presumably relies on
/// ed25519-dalek's own zeroize-on-drop (its `zeroize` feature) — confirm.
impl Drop for MLASigningPrivateKey {
    fn drop(&mut self) {
        self.private_key_seed_mldsa.zeroize();
    }
}
/// Complete MLA private key: a hybrid decryption key and a hybrid signing key,
/// as stored together in an MLA private key file.
#[derive(Clone)]
pub struct MLAPrivateKey {
    decryption_private_key: MLADecryptionPrivateKey,
    signing_private_key: MLASigningPrivateKey,
    // Key options; always the empty `KeyOpts` value in this file.
    opts: KeyOpts,
}
impl MLAPrivateKey {
pub fn deserialize_private_key(src: impl Read) -> Result<Self, Error> {
let mut content = zeroizeable_read_to_end(src)?;
let lines = split_in_five_parts_without_buffering(&content)?;
if lines[0] != PRIV_KEY_FILE_HEADER {
return Err(Error::DeserializationError);
}
if lines[4] != PRIV_KEY_FILE_FOOTER {
return Err(Error::DeserializationError);
}
let decryption_private_key =
MLADecryptionPrivateKey::deserialize_decryption_private_key(lines[1])?;
let signing_private_key = MLASigningPrivateKey::deserialize_signing_private_key(lines[2])?;
content.zeroize();
Ok(Self {
decryption_private_key,
signing_private_key,
opts: KeyOpts,
})
}
pub fn from_decryption_and_signature_keys(
decryption_private_key: MLADecryptionPrivateKey,
signing_private_key: MLASigningPrivateKey,
) -> Self {
MLAPrivateKey {
decryption_private_key,
signing_private_key,
opts: KeyOpts,
}
}
pub fn get_decryption_private_key(&self) -> &MLADecryptionPrivateKey {
&self.decryption_private_key
}
pub fn get_private_keys(self) -> (MLADecryptionPrivateKey, MLASigningPrivateKey) {
(self.decryption_private_key, self.signing_private_key)
}
pub fn get_signing_private_key(&self) -> &MLASigningPrivateKey {
&self.signing_private_key
}
pub fn serialize_private_key<W: Write>(&self, mut dst: W) -> Result<(), Error> {
dst.write_all(PRIV_KEY_FILE_HEADER)?;
dst.write_all(b"\r\n")?;
self.decryption_private_key
.serialize_decryption_private_key(&mut dst)?;
self.signing_private_key
.serialize_signing_private_key(&mut dst)?;
self.opts.serialize_key_opts(&mut dst)?;
dst.write_all(PRIV_KEY_FILE_FOOTER)?;
dst.write_all(b"\r\n")?;
Ok(())
}
}
impl MLAEncryptionPublicKey {
fn deserialize_encryption_public_key(line: &[u8]) -> Result<Self, Error> {
let b64data = line
.strip_prefix(MLA_PUB_ENC_KEY_HEADER)
.ok_or(Error::DeserializationError)?;
let data = base64_decode(b64data).map_err(|_| Error::DeserializationError)?;
let mut cursor = Cursor::new(data);
let mut method_id = [0; ENC_METHOD_ID_0_PUB.len()];
cursor
.read_exact(&mut method_id)
.map_err(|_| Error::DeserializationError)?;
if method_id.as_slice() != ENC_METHOD_ID_0_PUB {
return Err(Error::DeserializationError);
}
let _opts = KeyOpts::deserialize(&mut cursor)?;
let mut serialized_ecc_key = [0; ECC_PUBKEY_SIZE];
cursor
.read_exact(&mut serialized_ecc_key)
.map_err(|_| Error::DeserializationError)?;
let public_key_ecc = PublicKey::from(MontgomeryPoint(serialized_ecc_key).to_bytes());
let mut serialized_mlkem_key = Vec::new();
cursor
.read_to_end(&mut serialized_mlkem_key)
.map_err(|_| Error::DeserializationError)?;
let public_key_ml = MLKEMEncapsulationKey::from_bytes(
serialized_mlkem_key
.as_slice()
.try_into()
.map_err(|_| Error::DeserializationError)?,
);
Ok(Self {
public_key_ecc,
public_key_ml,
})
}
fn serialize_encryption_public_key<W: Write>(&self, mut dst: W) -> Result<(), Error> {
dst.write_all(MLA_PUB_ENC_KEY_HEADER)?;
let mut b64data = vec![];
b64data.extend_from_slice(ENC_METHOD_ID_0_PUB);
b64data.extend_from_slice(EMPTY_OPTS_SERIALIZATION); b64data.extend_from_slice(&self.public_key_ecc.to_bytes());
b64data.extend_from_slice(&self.public_key_ml.as_bytes());
dst.write_all(&(base64_encode(&b64data)?))?;
dst.write_all(b"\r\n")?;
Ok(())
}
}
/// Hybrid signature verification public key: an Ed25519 verifying key plus an
/// ML-DSA-87 verifying key.
#[derive(Clone)]
pub struct MLASignatureVerificationPublicKey {
    pub(crate) public_key_ed25519: Ed25519VerifyingKey,
    pub(crate) public_key_mldsa87: VerifyingKey<MlDsa87>,
    // Currently never read (hence `allow(dead_code)`).
    #[allow(dead_code)]
    opts: KeyOpts,
}
impl MLASignatureVerificationPublicKey {
fn deserialize_signature_verification_public_key(line: &[u8]) -> Result<Self, Error> {
let b64data = line
.strip_prefix(MLA_PUB_SIGVERIF_KEY_HEADER)
.ok_or(Error::DeserializationError)?;
let data = base64_decode(b64data).map_err(|_| Error::DeserializationError)?;
let mut cursor = Cursor::new(data);
let mut method_id = [0; SIGVERIF_METHOD_ID_0_PUB.len()];
cursor
.read_exact(&mut method_id)
.map_err(|_| Error::DeserializationError)?;
if method_id.as_slice() != SIGVERIF_METHOD_ID_0_PUB {
return Err(Error::DeserializationError);
}
let _opts = KeyOpts::deserialize(&mut cursor)?;
let mut serialized_ecc_key = [0; ECC_PUBKEY_SIZE];
cursor
.read_exact(&mut serialized_ecc_key)
.map_err(|_| Error::DeserializationError)?;
let public_key_ed25519 = Ed25519VerifyingKey::from_bytes(&serialized_ecc_key)
.map_err(|_| Error::DeserializationError)?;
let mut serialized_mldsa_key = Vec::new();
cursor
.read_to_end(&mut serialized_mldsa_key)
.map_err(|_| Error::DeserializationError)?;
let encoded_signing_mldsa87_key =
EncodedVerifyingKey::<MlDsa87>::try_from(serialized_mldsa_key.as_slice())
.map_err(|_| Error::DeserializationError)?;
let public_key_mldsa87 = VerifyingKey::<MlDsa87>::decode(&encoded_signing_mldsa87_key);
Ok(Self {
public_key_ed25519,
public_key_mldsa87,
opts: KeyOpts,
})
}
fn serialize_signature_verification_public_key<W: Write>(
&self,
mut dst: W,
) -> Result<(), Error> {
dst.write_all(MLA_PUB_SIGVERIF_KEY_HEADER)?;
let mut b64data = vec![];
b64data.extend_from_slice(SIGVERIF_METHOD_ID_0_PUB);
b64data.extend_from_slice(EMPTY_OPTS_SERIALIZATION); b64data.extend_from_slice(self.public_key_ed25519.as_bytes());
b64data.extend_from_slice(self.public_key_mldsa87.encode().as_slice());
let mut encoded = base64_encode(&b64data)?;
dst.write_all(&encoded)?;
encoded.zeroize();
dst.write_all(b"\r\n")?;
Ok(())
}
}
/// Complete MLA public key: a hybrid encryption key and a hybrid signature
/// verification key, as stored together in an MLA public key file.
#[derive(Clone)]
pub struct MLAPublicKey {
    encryption_public_key: MLAEncryptionPublicKey,
    signature_verification_public_key: MLASignatureVerificationPublicKey,
    // Key options; always the empty `KeyOpts` value in this file.
    opts: KeyOpts,
}
impl MLAPublicKey {
    /// Parses a complete MLA public key file from `src`.
    ///
    /// The file must consist of exactly five lines, each terminated by CRLF
    /// (or `__`), in this order:
    /// 1. header;
    /// 2. encryption key;
    /// 3. signature verification key;
    /// 4. key options (not parsed);
    /// 5. footer.
    ///
    /// # Errors
    /// `Error::DeserializationError` on read failure, wrong structure, or an
    /// invalid key line.
    pub fn deserialize_public_key(src: impl Read) -> Result<Self, Error> {
        let mut content = zeroizeable_read_to_end(src)?;
        let [header, encryption_line, verification_line, _opts_line, footer] =
            split_in_five_parts_without_buffering(&content)?;
        if header != PUB_KEY_FILE_HEADER || footer != PUB_KEY_FILE_FOOTER {
            return Err(Error::DeserializationError);
        }
        let encryption_public_key =
            MLAEncryptionPublicKey::deserialize_encryption_public_key(encryption_line)?;
        let signature_verification_public_key =
            MLASignatureVerificationPublicKey::deserialize_signature_verification_public_key(
                verification_line,
            )?;
        content.zeroize();
        Ok(Self {
            encryption_public_key,
            signature_verification_public_key,
            opts: KeyOpts,
        })
    }

    /// Bundles existing encryption and signature verification keys into one public key.
    pub fn from_encryption_and_signature_verification_keys(
        encryption_public_key: MLAEncryptionPublicKey,
        signature_verification_public_key: MLASignatureVerificationPublicKey,
    ) -> Self {
        Self {
            encryption_public_key,
            signature_verification_public_key,
            opts: KeyOpts,
        }
    }

    /// Borrows the encryption half of the key.
    pub fn get_encryption_public_key(&self) -> &MLAEncryptionPublicKey {
        &self.encryption_public_key
    }

    /// Consumes the key and returns its encryption and verification halves.
    pub fn get_public_keys(self) -> (MLAEncryptionPublicKey, MLASignatureVerificationPublicKey) {
        let Self {
            encryption_public_key,
            signature_verification_public_key,
            ..
        } = self;
        (encryption_public_key, signature_verification_public_key)
    }

    /// Borrows the signature verification half of the key.
    pub fn get_signature_verification_public_key(&self) -> &MLASignatureVerificationPublicKey {
        &self.signature_verification_public_key
    }

    /// Writes the complete public key file (five CRLF-terminated lines) to `dst`.
    ///
    /// # Errors
    /// Propagates encoding and write failures.
    pub fn serialize_public_key<W: Write>(&self, mut dst: W) -> Result<(), Error> {
        dst.write_all(PUB_KEY_FILE_HEADER)?;
        dst.write_all(b"\r\n")?;
        self.encryption_public_key
            .serialize_encryption_public_key(&mut dst)?;
        self.signature_verification_public_key
            .serialize_signature_verification_public_key(&mut dst)?;
        self.opts.serialize_key_opts(&mut dst)?;
        dst.write_all(PUB_KEY_FILE_FOOTER)?;
        dst.write_all(b"\r\n")?;
        Ok(())
    }
}
/// Generates a hybrid Ed25519 + ML-DSA-87 signing key pair from `csprng`.
///
/// The Ed25519 key is drawn first and the ML-DSA seed second; keeping this
/// order keeps seeded generation reproducible.
fn generate_signature_keypair_from_rng(
    mut csprng: impl CryptoRngCore,
) -> (MLASigningPrivateKey, MLASignatureVerificationPublicKey) {
    let private_key_ed25519 = Ed25519SigningKey::generate(&mut csprng);
    let private_key_seed_mldsa = MLDSASeed::generate_from_csprng(&mut csprng);
    let verification_key = MLASignatureVerificationPublicKey {
        public_key_ed25519: private_key_ed25519.verifying_key(),
        public_key_mldsa87: private_key_seed_mldsa.to_signing_verification_key(),
        opts: KeyOpts,
    };
    let signing_key = MLASigningPrivateKey {
        private_key_ed25519,
        private_key_seed_mldsa,
        opts: KeyOpts,
    };
    (signing_key, verification_key)
}
/// Generates a fresh MLA key pair using the crate's cryptographic RNG.
///
/// # Errors
/// Propagates a failure to obtain the RNG.
pub fn generate_mla_keypair() -> Result<(MLAPrivateKey, MLAPublicKey), Error> {
    let rng = get_crypto_rng()?;
    Ok(generate_mla_keypair_from_rng(rng))
}
/// Deterministically derives an MLA key pair from a 32-byte seed, using
/// ChaCha20 as the RNG. The same seed always yields the same key pair.
pub fn generate_mla_keypair_from_seed(seed: [u8; 32]) -> (MLAPrivateKey, MLAPublicKey) {
    generate_mla_keypair_from_rng(ChaCha20Rng::from_seed(seed))
}
/// Generates a full MLA key pair from `csprng`.
///
/// The encryption key pair is drawn first and the signature key pair second;
/// this order determines the output for a seeded RNG.
fn generate_mla_keypair_from_rng(mut csprng: impl CryptoRngCore) -> (MLAPrivateKey, MLAPublicKey) {
    let (decryption_private_key, encryption_public_key) = generate_keypair_from_rng(&mut csprng);
    let (signing_private_key, signature_verification_public_key) =
        generate_signature_keypair_from_rng(&mut csprng);
    (
        MLAPrivateKey {
            decryption_private_key,
            signing_private_key,
            opts: KeyOpts,
        },
        MLAPublicKey {
            encryption_public_key,
            signature_verification_public_key,
            opts: KeyOpts,
        },
    )
}
/// Size in bytes of a serialized X25519 private key.
const ECC_PRIVKEY_SIZE: usize = 32;
/// Size in bytes of a serialized 32-byte elliptic-curve public key (used for
/// both the X25519 and the Ed25519 public keys in this file).
const ECC_PUBKEY_SIZE: usize = 32;
// Despite its name, this is used as the first HKDF *info* component (not as
// the HKDF salt) in `apply_derive`.
const DERIVE_PATH_SALT: &[u8; 15] = b"PATH DERIVATION";
/// Derives a 32-byte seed for one path component from a decryption private key.
///
/// The derivation uses HKDF-SHA512 in two steps:
/// 1. HKDF-Extract (no salt) over the X25519 secret produces a PRK.
/// 2. That PRK is the salt of a second HKDF over the ML-KEM `d || z` seed,
///    expanded with info `DERIVE_PATH_SALT || path`.
///
/// The key is taken by value so it is dropped once the seed has been derived.
///
/// # Panics
/// Only if HKDF rejects the output length, which cannot happen for 32 bytes
/// with SHA-512.
#[allow(clippy::needless_pass_by_value)]
fn apply_derive(path: &[u8], src: MLADecryptionPrivateKey) -> [u8; 32] {
    const SEED_LEN: usize = 32;
    let mut derived_seed = [0u8; SEED_LEN];
    let (ecc_prk, _) = Hkdf::<Sha512>::extract(None, src.private_key_ecc.as_bytes());
    Hkdf::<Sha512>::new(Some(&ecc_prk), src.private_key_seed_ml.to_d_z_64().as_ref())
        .expand_multi_info(&[DERIVE_PATH_SALT, path], &mut derived_seed)
        .expect("Unexpected error while derivating along the path");
    derived_seed
}
/// Derives the child key pair for a single path component of `privkey`.
fn derive_one_path_component(
    path: &[u8],
    privkey: MLADecryptionPrivateKey,
) -> (MLAPrivateKey, MLAPublicKey) {
    generate_mla_keypair_from_seed(apply_derive(path, privkey))
}
/// Derives a key pair by walking `path_components` from `src`.
///
/// Each step derives a new key pair from the previous private decryption key
/// and the current component. Returns the key pair produced by the last step,
/// or `None` when `path_components` is empty.
pub fn derive_keypair_from_path<'a>(
    path_components: impl Iterator<Item = &'a [u8]>,
    src: MLAPrivateKey,
) -> Option<(MLAPrivateKey, MLAPublicKey)> {
    let mut current = src;
    let mut last_public_key = None;
    for component in path_components {
        let (decryption_key, _) = current.get_private_keys();
        let (next_private_key, next_public_key) =
            derive_one_path_component(component, decryption_key);
        current = next_private_key;
        last_public_key = Some(next_public_key);
    }
    last_public_key.map(|public_key| (current, public_key))
}
#[cfg(test)]
mod tests {
    use crate::crypto::hybrid::generate_keypair_from_seed;
    use std::io::{Seek, SeekFrom};
    use super::*;
    use ml_kem::kem::{Decapsulate, Encapsulate};
    use x25519_dalek::PublicKey;

    /// Asserts that `pub_key` and `priv_key` form a working hybrid KEM pair:
    /// - the key sizes are as expected;
    /// - the X25519 public key matches the secret;
    /// - an ML-KEM encapsulation decapsulates to the same shared key.
    fn check_key_pair(pub_key: &MLAEncryptionPublicKey, priv_key: &MLADecryptionPrivateKey) {
        const MLKEM_1024_PUBKEY_SIZE: usize = 1568;
        let computed_ecc_pubkey = PublicKey::from(&priv_key.private_key_ecc);
        assert_eq!(pub_key.public_key_ecc.as_bytes().len(), ECC_PUBKEY_SIZE);
        assert_eq!(priv_key.private_key_ecc.as_bytes().len(), ECC_PRIVKEY_SIZE);
        assert_eq!(
            pub_key.public_key_ecc.as_bytes(),
            computed_ecc_pubkey.as_bytes()
        );
        assert_eq!(
            pub_key.public_key_ml.as_bytes().len(),
            MLKEM_1024_PUBKEY_SIZE
        );
        let mut rng = rand::rngs::OsRng {};
        let (encap, key) = pub_key.public_key_ml.encapsulate(&mut rng).unwrap();
        let key_decap = priv_key
            .private_key_seed_ml
            .to_privkey()
            .decapsulate(&encap)
            .unwrap();
        assert_eq!(key, key_decap);
    }

    /// Serializes a deterministic private key file and returns it as text
    /// (the file is pure ASCII: header, base64 lines and footer).
    fn serialized_private_key_text() -> String {
        let (priv_key, _) = generate_mla_keypair_from_seed([2; 32]);
        let mut serialized = Vec::new();
        priv_key.serialize_private_key(&mut serialized).unwrap();
        String::from_utf8(serialized).unwrap()
    }

    /// Joins `lines` into a key file body, terminating each line with CRLF.
    fn join_crlf(lines: &[&str]) -> Vec<u8> {
        let mut out = Vec::new();
        for line in lines {
            out.extend_from_slice(line.as_bytes());
            out.extend_from_slice(b"\r\n");
        }
        out
    }

    /// Splits a valid serialized private key into its five lines and checks
    /// that re-joining them parses. Any later failure is then due to the
    /// modification under test, not to the fixture.
    fn valid_private_key_lines(valid: &str) -> Vec<&str> {
        let lines: Vec<&str> = valid.split_terminator("\r\n").collect();
        assert_eq!(lines.len(), 5);
        assert!(MLAPrivateKey::deserialize_private_key(join_crlf(&lines).as_slice()).is_ok());
        lines
    }

    #[test]
    fn keypair_serialize_deserialize_and_check() {
        let (priv_key, pub_key) = generate_mla_keypair().unwrap();
        let mut cursor = Cursor::new(Vec::new());
        priv_key.serialize_private_key(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let priv_key = MLAPrivateKey::deserialize_private_key(&mut cursor).unwrap();
        // Use a fresh cursor so leftover private-key bytes can never trail the public key.
        let mut cursor = Cursor::new(Vec::new());
        pub_key.serialize_public_key(&mut cursor).unwrap();
        cursor.seek(SeekFrom::Start(0)).unwrap();
        let pub_key = MLAPublicKey::deserialize_public_key(&mut cursor).unwrap();
        check_key_pair(
            pub_key.get_encryption_public_key(),
            priv_key.get_decryption_private_key(),
        );
    }

    #[test]
    fn keypair_without_newlines() {
        let (priv_key, pub_key) = generate_mla_keypair().unwrap();
        let mut ser1 = Vec::new();
        priv_key.serialize_private_key(&mut ser1).unwrap();
        let mut replaced = ser1.clone();
        for b in &mut replaced {
            if *b == b'\r' || *b == b'\n' {
                *b = b'_';
            }
        }
        let priv_key2 = MLAPrivateKey::deserialize_private_key(replaced.as_slice()).unwrap();
        let mut ser2 = Vec::new();
        priv_key2.serialize_private_key(&mut ser2).unwrap();
        assert_eq!(ser1, ser2);
        let mut ser1 = Vec::new();
        pub_key.serialize_public_key(&mut ser1).unwrap();
        let mut replaced = ser1.clone();
        for b in &mut replaced {
            if *b == b'\r' || *b == b'\n' {
                *b = b'_';
            }
        }
        let pub_key2 = MLAPublicKey::deserialize_public_key(replaced.as_slice()).unwrap();
        let mut ser2 = Vec::new();
        pub_key2.serialize_public_key(&mut ser2).unwrap();
        assert_eq!(ser1, ser2);
        check_key_pair(
            pub_key2.get_encryption_public_key(),
            priv_key2.get_decryption_private_key(),
        );
    }

    #[test]
    fn keypair_deterministic() {
        let (priv1, pub1) = generate_mla_keypair_from_seed([0; 32]);
        let (priv2, pub2) = generate_mla_keypair_from_seed([0; 32]);
        let mut priv1s = Vec::new();
        let mut pub1s = Vec::new();
        let mut priv2s = Vec::new();
        let mut pub2s = Vec::new();
        priv1.serialize_private_key(&mut priv1s).unwrap();
        pub1.serialize_public_key(&mut pub1s).unwrap();
        priv2.serialize_private_key(&mut priv2s).unwrap();
        pub2.serialize_public_key(&mut pub2s).unwrap();
        assert_eq!(priv1s, priv2s);
        assert_eq!(pub1s, pub2s);
        let (priv3, pub3) = generate_mla_keypair_from_seed([1; 32]);
        let mut priv3s = Vec::new();
        let mut pub3s = Vec::new();
        priv3.serialize_private_key(&mut priv3s).unwrap();
        pub3.serialize_public_key(&mut pub3s).unwrap();
        assert_ne!(priv1s, priv3s);
        assert_ne!(pub1s, pub3s);
    }

    #[test]
    fn check_apply_derive() {
        use std::collections::HashSet;
        use x25519_dalek::StaticSecret;
        const SEED_LEN: usize = 32;
        let (privkey, _pubkey) = generate_keypair_from_seed([0; 32]);
        let path = b"test";
        let seed = apply_derive(path, privkey);
        assert_ne!(seed, [0u8; SEED_LEN]);
        let (privkey, _pubkey) = generate_keypair_from_seed([0; 32]);
        let path = b"test2";
        let seed_2 = apply_derive(path, privkey);
        assert_ne!(seed, seed_2);
        // Vary both halves of the key independently; every combination must
        // yield a distinct seed for the same path.
        let mut priv_keys = vec![];
        for i in 0..2 {
            for j in 0..2 {
                priv_keys.push(MLADecryptionPrivateKey {
                    private_key_ecc: StaticSecret::from([i; 32]),
                    private_key_seed_ml: MLKEMSeed::from_d_z_64([j; 64]),
                });
            }
        }
        let seeds: Vec<_> = priv_keys
            .into_iter()
            .map(|pkey| apply_derive(b"test", pkey))
            .collect();
        assert_eq!(seeds.len(), 4);
        assert_eq!((seeds.iter().collect::<HashSet<_>>()).len(), seeds.len());
    }

    #[test]
    fn check_derive_paths() {
        let ser_priv: &'static [u8] = include_bytes!("../../../samples/test_mlakey.mlapriv");
        let ser_derived_priv: &'static [u8] =
            include_bytes!("../../../samples/test_mlakey_derived.mlapriv");
        let secret =
            crate::crypto::mlakey::MLAPrivateKey::deserialize_private_key(ser_priv).unwrap();
        let path = [b"pathcomponent1".as_slice(), b"pathcomponent2".as_slice()];
        let (privkey, _) = derive_keypair_from_path(path.into_iter(), secret).unwrap();
        let mut computed_ser_derived_priv = Vec::new();
        privkey
            .serialize_private_key(&mut computed_ser_derived_priv)
            .unwrap();
        assert_eq!(computed_ser_derived_priv.as_slice(), ser_derived_priv);
    }

    #[test]
    fn test_deserialization_errors() {
        use std::io::Cursor;
        let missing_header = b"WRONG HEADER bWxhLWtlbS1wdWJsaWMtMTIzNDU2\n";
        let mut cursor = Cursor::new(&missing_header[..]);
        let result = MLAPrivateKey::deserialize_private_key(&mut cursor);
        assert!(matches!(result, Err(Error::DeserializationError)));
        let corrupted_base64 = b"MLA PRIVATE DECRYPTION KEY !!@@##\n";
        let mut cursor = Cursor::new(&corrupted_base64[..]);
        let result = MLAPrivateKey::deserialize_private_key(&mut cursor);
        assert!(matches!(result, Err(Error::DeserializationError)));
        let bad_method_id = b"MLA PRIVATE DECRYPTION KEY QUFB\n";
        let mut cursor = Cursor::new(&bad_method_id[..]);
        let result = MLAPrivateKey::deserialize_private_key(&mut cursor);
        assert!(matches!(result, Err(Error::DeserializationError)));
        let truncated_data = b"MLA PRIVATE DECRYPTION KEY bWxh\n";
        let mut cursor = Cursor::new(&truncated_data[..]);
        let result = MLAPrivateKey::deserialize_private_key(&mut cursor);
        assert!(matches!(result, Err(Error::DeserializationError)));
    }

    #[test]
    fn test_deserialize_private_key_wrong_line_count_too_few() {
        let valid = serialized_private_key_text();
        let lines = valid_private_key_lines(&valid);
        // Drop the key-options line, leaving four lines.
        let input = join_crlf(&[lines[0], lines[1], lines[2], lines[4]]);
        let result = MLAPrivateKey::deserialize_private_key(input.as_slice());
        assert!(matches!(result, Err(Error::DeserializationError)));
    }

    #[test]
    fn test_deserialize_private_key_wrong_line_count_too_many() {
        let valid = serialized_private_key_text();
        let mut lines = valid_private_key_lines(&valid);
        lines.push("EXTRA LINE");
        let input = join_crlf(&lines);
        let result = MLAPrivateKey::deserialize_private_key(input.as_slice());
        assert!(matches!(result, Err(Error::DeserializationError)));
    }

    #[test]
    fn test_deserialize_private_key_missing_crlf() {
        let valid = serialized_private_key_text();
        let _ = valid_private_key_lines(&valid);
        let input = valid.replace("\r\n", "\n");
        let result = MLAPrivateKey::deserialize_private_key(input.as_bytes());
        assert!(matches!(result, Err(Error::DeserializationError)));
    }

    #[test]
    fn test_deserialize_private_key_header_footer_case_sensitive() {
        let valid = serialized_private_key_text();
        let lines = valid_private_key_lines(&valid);

        let lower_header = lines[0].to_ascii_lowercase();
        let input = join_crlf(&[lower_header.as_str(), lines[1], lines[2], lines[3], lines[4]]);
        let result = MLAPrivateKey::deserialize_private_key(input.as_slice());
        assert!(matches!(result, Err(Error::DeserializationError)));

        let lower_footer = lines[4].to_ascii_lowercase();
        let input = join_crlf(&[lines[0], lines[1], lines[2], lines[3], lower_footer.as_str()]);
        let result = MLAPrivateKey::deserialize_private_key(input.as_slice());
        assert!(matches!(result, Err(Error::DeserializationError)));
    }
}