use function_name::named;
use ml_dsa as foreign_mldsa_module;
pub(super) const ML_DSA_SEED_SIZE: usize = 32;
pub(super) type MlDsaSeed = [u8; ML_DSA_SEED_SIZE];
pub(super) trait SupportedMlDsaSecretKey: pqcrypto_traits::sign::SecretKey {
type ForeignParamSet: foreign_mldsa_module::MlDsaParams;
type PublicKey;
}
impl SupportedMlDsaSecretKey for pqcrypto_mldsa::mldsa44::SecretKey {
type ForeignParamSet = foreign_mldsa_module::MlDsa44;
type PublicKey = pqcrypto_mldsa::mldsa44::PublicKey;
}
impl SupportedMlDsaSecretKey for pqcrypto_mldsa::mldsa65::SecretKey {
type ForeignParamSet = foreign_mldsa_module::MlDsa65;
type PublicKey = pqcrypto_mldsa::mldsa65::PublicKey;
}
impl SupportedMlDsaSecretKey for pqcrypto_mldsa::mldsa87::SecretKey {
type ForeignParamSet = foreign_mldsa_module::MlDsa87;
type PublicKey = pqcrypto_mldsa::mldsa87::PublicKey;
}
#[named]
pub(super) fn derive_mldsa_public_key<T>(sk: &T) -> Option<T::PublicKey>
where
T: SupportedMlDsaSecretKey,
<T as SupportedMlDsaSecretKey>::PublicKey: pqcrypto_traits::sign::PublicKey,
{
let encoded_sk = <T as pqcrypto_traits::sign::SecretKey>::as_bytes(sk);
let encoded_sk = match encoded_sk.try_into() {
Ok(p) => p,
Err(e) => {
error!(target: log_target!(), "Failed to encode secret key to bytes: {e:?}");
return None;
}
};
let csk = <foreign_mldsa_module::SigningKey<T::ForeignParamSet>>::decode(encoded_sk);
let cpk = csk.verifying_key();
let pk_bytes = cpk.encode();
let pk_bytes = pk_bytes.as_slice();
let res =
<<T as SupportedMlDsaSecretKey>::PublicKey as pqcrypto_traits::sign::PublicKey>::from_bytes(
pk_bytes,
);
match res {
Ok(pk) => Some(pk),
Err(e) => {
error!(target: log_target!(), "Failed to derive the public key from the inner private key: {e:?}");
return None;
}
}
}
/// Expands an ML-DSA seed into the secret key representation used by `T`.
///
/// Key expansion is done by the foreign `ml_dsa` implementation. Its encoded signing key is then
/// re-parsed with `T::from_bytes`. Returns `None`, after logging the error, if that parse fails.
#[named]
pub(super) fn derive_mldsa_secret_key_from_seed<T>(seed: &MlDsaSeed) -> Option<T>
where
    T: SupportedMlDsaSecretKey,
{
    let expanded_sk_bytes =
        <foreign_mldsa_module::SigningKey<T::ForeignParamSet>>::from_seed(seed.into()).encode();
    <T as pqcrypto_traits::sign::SecretKey>::from_bytes(&expanded_sk_bytes)
        .map_err(|e| {
            error!(target: log_target!(), "Failed to derive the expanded private key from the seed: {e:?}");
        })
        .ok()
}
/// When `true`, `decode_mldsa_secret_key` refuses an expanded (non-seed) secret key unless the
/// foreign `ml_dsa` decoder also accepts it. Only then are the bytes handed to `T::from_bytes`.
const VALIDATE_PRIVKEY_DECODING_VIA_FOREIGN_MODULE: bool = true;

/// Decodes an expanded ML-DSA secret key with the foreign `ml_dsa` implementation.
///
/// The foreign decoder is run under [`std::panic::catch_unwind`], so a panic on malformed input
/// becomes an `Err`.
///
/// While the decoder runs, the process-wide panic hook is temporarily replaced:
/// * a panic on *this* thread is only logged at `trace` level;
/// * panics on other threads are still forwarded to the previously installed hook.
///
/// Hook swaps are serialized by a lock, so concurrent callers cannot restore each other's
/// temporary hook and leave it installed.
///
/// # Errors
///
/// * The boxed slice-conversion error, if `bytes` is not the encoded signing-key length for
///   `T::ForeignParamSet`.
/// * The panic payload, if the foreign decoder panicked.
#[named]
fn foreign_decode_mldsa_secret_key<T>(
    bytes: &[u8],
) -> std::thread::Result<
    foreign_mldsa_module::SigningKey<<T as SupportedMlDsaSecretKey>::ForeignParamSet>,
>
where
    T: SupportedMlDsaSecretKey,
{
    use std::panic::{self, catch_unwind, AssertUnwindSafe};
    use std::sync::{Arc, Mutex, PoisonError};
    use std::thread;

    // `take_hook`/`set_hook` act on process-global state. Without this lock, two overlapping
    // calls could run take(A) -> take(B) -> restore(A) -> restore(B), which leaves A's temporary
    // hook installed forever.
    static HOOK_SWAP_LOCK: Mutex<()> = Mutex::new(());

    let encoded = match bytes.try_into() {
        Ok(encoded) => encoded,
        Err(e) => {
            error!(target: log_target!(), "Found wrong length when decoding EncodedPrivateKey: {e:?}");
            return Err(Box::new(e));
        }
    };

    // The lock guards no data, so a poisoned lock is harmless and can simply be reclaimed.
    let _hook_guard = HOOK_SWAP_LOCK
        .lock()
        .unwrap_or_else(PoisonError::into_inner);

    let prev_hook = Arc::new(panic::take_hook());
    let forward_to = Arc::clone(&prev_hook);
    let decoding_thread = thread::current().id();
    panic::set_hook(Box::new(move |info| {
        if thread::current().id() == decoding_thread {
            trace!(target: log_target!(), "Caught panic: {}", info);
        } else {
            // Panics on unrelated threads must still be reported as usual.
            (**forward_to)(info);
        }
    }));

    let result = catch_unwind(AssertUnwindSafe(|| {
        <foreign_mldsa_module::SigningKey<T::ForeignParamSet>>::decode(encoded)
    }));

    // Taking and dropping the temporary hook releases its clone of `prev_hook`. That normally
    // lets the original boxed hook be moved back out of the `Arc` unchanged.
    drop(panic::take_hook());
    match Arc::try_unwrap(prev_hook) {
        Ok(hook) => panic::set_hook(hook),
        Err(shared) => panic::set_hook(Box::new(move |info| (**shared)(info))),
    }
    result
}
#[named]
pub(super) fn decode_mldsa_secret_key<T>(bytes: &[u8]) -> Option<T>
where
T: SupportedMlDsaSecretKey,
{
match TryInto::<&MlDsaSeed>::try_into(bytes) {
Ok(seed) => {
return derive_mldsa_secret_key_from_seed(seed);
}
Err(_) => (),
}
if VALIDATE_PRIVKEY_DECODING_VIA_FOREIGN_MODULE {
let foreign_result = foreign_decode_mldsa_secret_key::<T>(bytes);
match foreign_result {
Ok(_) => (), Err(e) => {
if let Some(s) = e.downcast_ref::<&str>() {
error!(target: log_target!(), "Failed to decode the EncodedPrivateKey: {s}");
} else if let Some(s) = e.downcast_ref::<String>() {
error!(target: log_target!(), "Failed to decode the EncodedPrivateKey: {s}");
} else {
error!(target: log_target!(), "Failed to decode the EncodedPrivateKey");
}
return None;
}
}
}
T::from_bytes(bytes).ok()
}