use ml_kem::array::sizes::U64;
use ml_kem::array::Array;
use ml_kem::{DecapsulationKey, KeyExport, MlKem768};
use sha3::digest::{ExtendableOutput, Update, XofReader};
use sha3::Shake256;
use thiserror::Error;
use x25519_dalek::x25519;
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::kdf::hkdf_sha256;
/// Required length, in bytes, of the root seed accepted by the `derive_*` functions and
/// [`signer_from_seed`].
pub const SEED_LENGTH: usize = 32;
/// Length, in bytes, of the HKDF-SHA256 output produced for each derived secret.
const DERIVED_LENGTH: usize = 32;
/// HKDF `info` label that domain-separates the Ed25519 signing key.
pub const INFO_ED25519: &[u8] = b"cardano-poe-ed25519-v1";
/// HKDF `info` label that domain-separates the X25519 key pair.
pub const INFO_X25519: &[u8] = b"cardano-poe-x25519-v1";
/// HKDF `info` label that domain-separates the X-Wing (ML-KEM-768 + X25519) seed.
pub const INFO_MLKEM768X25519: &[u8] = b"cardano-poe-mlkem768x25519-v1";
/// Length of an X-Wing public key: the encoded ML-KEM-768 encapsulation key
/// (`MLKEM_EK_LENGTH` = 1184 bytes) followed by a 32-byte X25519 public key.
pub const MLKEM768X25519_PUBLIC_KEY_LENGTH: usize = 1216;
/// Length, in bytes, of an encoded ML-KEM-768 encapsulation key.
const MLKEM_EK_LENGTH: usize = 1184;
/// Number of SHAKE256 output bytes drawn from an X-Wing seed: a 64-byte ML-KEM-768 seed
/// followed by a 32-byte X25519 secret scalar.
pub const XWING_EXPANDED_SEED_LENGTH: usize = 96;
/// Errors returned when deriving keys from a root seed.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum SeedDeriveError {
    /// The supplied seed was not [`SEED_LENGTH`] bytes long; carries the actual length.
    #[error("seed must be exactly 32 bytes, got {0}")]
    InvalidSeedLength(usize),
}
/// Ed25519 key pair produced by [`derive_ed25519_keypair`].
///
/// `secret_key` is wiped when the value (or any clone of it) is dropped; `public_key` is
/// excluded from zeroization because it is public material.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct DerivedEd25519KeyPair {
    /// 32-byte secret key: the HKDF-SHA256 output under [`INFO_ED25519`], passed as-is to
    /// `ed25519_dalek::SigningKey::from_bytes`.
    pub secret_key: [u8; 32],
    /// Verifying key corresponding to `secret_key`.
    #[zeroize(skip)]
    pub public_key: [u8; 32],
}
/// X25519 key pair produced by [`derive_x25519_keypair`].
///
/// `secret_key` is wiped when the value (or any clone of it) is dropped; `public_key` is
/// excluded from zeroization because it is public material.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct DerivedX25519KeyPair {
    /// 32-byte secret scalar exactly as output by HKDF-SHA256 under [`INFO_X25519`]
    /// (stored unclamped).
    pub secret_key: [u8; 32],
    /// `X25519(secret_key, basepoint)`.
    #[zeroize(skip)]
    pub public_key: [u8; 32],
}
/// X-Wing (ML-KEM-768 + X25519) key pair produced by [`derive_mlkem768x25519_keypair`].
///
/// Only the compact 32-byte seed is kept as the secret; the full key material is re-derived
/// from it by [`xwing_keygen`] / [`expand_xwing_seed`]. The seed is wiped on drop.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct DerivedMlKem768X25519KeyPair {
    /// 32-byte X-Wing seed: the HKDF-SHA256 output under [`INFO_MLKEM768X25519`]
    /// (not the caller's root seed).
    pub secret_seed: [u8; 32],
    /// ML-KEM-768 encapsulation key (1184 bytes) followed by the X25519 public key (32 bytes).
    #[zeroize(skip)]
    pub public_key: [u8; MLKEM768X25519_PUBLIC_KEY_LENGTH],
}
/// Copies `seed` into a fixed-size array, rejecting any slice that is not exactly
/// [`SEED_LENGTH`] bytes long.
fn checked_seed(seed: &[u8]) -> Result<[u8; SEED_LENGTH], SeedDeriveError> {
    if seed.len() != SEED_LENGTH {
        return Err(SeedDeriveError::InvalidSeedLength(seed.len()));
    }
    let mut fixed = [0u8; SEED_LENGTH];
    fixed.copy_from_slice(seed);
    Ok(fixed)
}
/// Expands the 32-byte root `seed` into [`DERIVED_LENGTH`] bytes of key material with
/// HKDF-SHA256, using an empty salt and `info` as the domain-separation label.
///
/// # Panics
/// Only if `hkdf_sha256` breaks its contract, i.e. it fails for a 32-byte request or returns
/// a buffer of a different length. Both cases are treated as invariant violations.
fn derive_key_material(seed: &[u8; SEED_LENGTH], info: &[u8]) -> [u8; DERIVED_LENGTH] {
    let mut okm = hkdf_sha256(seed, &[], info, DERIVED_LENGTH)
        .expect("32-byte HKDF output is well within the RFC 5869 maximum");
    // Copy the secret out of the heap buffer and wipe the buffer before it is freed.
    // Converting the buffer with `try_into` would consume it and release the allocation
    // with the secret bytes still in it.
    let copied = <[u8; DERIVED_LENGTH]>::try_from(&okm[..]);
    okm.zeroize();
    copied.expect("hkdf_sha256 returns exactly the requested length")
}
/// Derives the Ed25519 key pair bound to `seed` under the [`INFO_ED25519`] label.
///
/// The HKDF-SHA256 output is used directly as the bytes passed to
/// `ed25519_dalek::SigningKey::from_bytes`; `public_key` is that key's verifying key.
///
/// # Errors
/// [`SeedDeriveError::InvalidSeedLength`] if `seed` is not exactly [`SEED_LENGTH`] bytes.
pub fn derive_ed25519_keypair(seed: &[u8]) -> Result<DerivedEd25519KeyPair, SeedDeriveError> {
    let mut root = checked_seed(seed)?;
    let secret_key = derive_key_material(&root, INFO_ED25519);
    // The root seed derives every key in this module and is not returned, so wipe the
    // local copy as soon as it has been consumed.
    root.zeroize();
    let public_key = ed25519_dalek::SigningKey::from_bytes(&secret_key)
        .verifying_key()
        .to_bytes();
    Ok(DerivedEd25519KeyPair {
        secret_key,
        public_key,
    })
}
/// Derives the X25519 key pair bound to `seed` under the [`INFO_X25519`] label.
///
/// The secret is the raw HKDF-SHA256 output, stored unclamped. The public key is
/// `X25519(secret, basepoint)`.
///
/// # Errors
/// [`SeedDeriveError::InvalidSeedLength`] if `seed` is not exactly [`SEED_LENGTH`] bytes.
pub fn derive_x25519_keypair(seed: &[u8]) -> Result<DerivedX25519KeyPair, SeedDeriveError> {
    let mut root = checked_seed(seed)?;
    let secret_key = derive_key_material(&root, INFO_X25519);
    // The root seed derives every key in this module and is not returned, so wipe the
    // local copy as soon as it has been consumed.
    root.zeroize();
    let public_key = x25519_public_key(&secret_key);
    Ok(DerivedX25519KeyPair {
        secret_key,
        public_key,
    })
}
/// Derives the X-Wing (ML-KEM-768 + X25519) key pair bound to `seed` under the
/// [`INFO_MLKEM768X25519`] label.
///
/// The 32-byte HKDF-SHA256 output becomes the X-Wing seed stored in `secret_seed`. The
/// public key is computed from it by [`xwing_keygen`].
///
/// # Errors
/// [`SeedDeriveError::InvalidSeedLength`] if `seed` is not exactly [`SEED_LENGTH`] bytes.
pub fn derive_mlkem768x25519_keypair(
    seed: &[u8],
) -> Result<DerivedMlKem768X25519KeyPair, SeedDeriveError> {
    let mut root = checked_seed(seed)?;
    let xwing_seed = derive_key_material(&root, INFO_MLKEM768X25519);
    // The root seed derives every key in this module and is not returned, so wipe the
    // local copy as soon as it has been consumed.
    root.zeroize();
    let public_key = xwing_keygen(&xwing_seed);
    Ok(DerivedMlKem768X25519KeyPair {
        secret_seed: xwing_seed,
        public_key,
    })
}
/// Returns the X25519 public key for `secret_scalar`: `x25519_dalek::x25519` applied to the
/// scalar and the standard base point bytes.
fn x25519_public_key(secret_scalar: &[u8; 32]) -> [u8; 32] {
    let base_point = x25519_dalek::X25519_BASEPOINT_BYTES;
    x25519(*secret_scalar, base_point)
}
/// Computes the X-Wing public key for a 32-byte X-Wing `seed`.
///
/// The seed is expanded to 96 bytes with [`expand_xwing_seed`]:
/// - bytes `0..64` seed ML-KEM-768 key generation;
/// - bytes `64..96` are the X25519 secret scalar.
///
/// The result is the encoded encapsulation key (1184 bytes) followed by the X25519 public
/// key (32 bytes). The expanded material and the scalar copy are wiped before returning.
#[must_use]
pub fn xwing_keygen(seed: &[u8; SEED_LENGTH]) -> [u8; MLKEM768X25519_PUBLIC_KEY_LENGTH] {
    let mut material = expand_xwing_seed(seed);
    let (kem_part, dh_part) = material.split_at(64);

    // ML-KEM-768 half.
    let kem_seed: Array<u8, U64> = Array::try_from(kem_part)
        .expect("the 96-byte expansion always yields a 64-byte ML-KEM seed prefix");
    let decaps = DecapsulationKey::<MlKem768>::from_seed(kem_seed);
    let encaps_bytes = decaps.encapsulation_key().to_bytes();

    // X25519 half.
    let mut dh_secret = [0u8; 32];
    dh_secret.copy_from_slice(dh_part);
    let dh_public = x25519_public_key(&dh_secret);

    let mut combined = [0u8; MLKEM768X25519_PUBLIC_KEY_LENGTH];
    let (kem_dst, dh_dst) = combined.split_at_mut(MLKEM_EK_LENGTH);
    kem_dst.copy_from_slice(encaps_bytes.as_slice());
    dh_dst.copy_from_slice(&dh_public);

    material.zeroize();
    dh_secret.zeroize();
    combined
}
/// Expands a 32-byte X-Wing seed into [`XWING_EXPANDED_SEED_LENGTH`] bytes of SHAKE256
/// output.
///
/// The result is secret key material; callers are responsible for zeroizing it.
#[must_use]
pub fn expand_xwing_seed(seed: &[u8; SEED_LENGTH]) -> [u8; XWING_EXPANDED_SEED_LENGTH] {
    let mut output = [0u8; XWING_EXPANDED_SEED_LENGTH];
    let mut xof = Shake256::default();
    xof.update(seed);
    xof.finalize_xof().read(&mut output);
    output
}
/// Ed25519 signer whose key is derived from a root seed; build it with [`signer_from_seed`].
///
/// The secret key is private, wiped on drop, and omitted from the `Debug` output.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct SeedSigner {
    // Ed25519 secret key bytes, as produced by `derive_ed25519_keypair`.
    secret_key: [u8; 32],
    // Verifying key matching `secret_key`; public, so not zeroized.
    #[zeroize(skip)]
    public_key: [u8; 32],
}
impl SeedSigner {
    /// Returns a copy of this signer's 32-byte Ed25519 verifying key.
    #[must_use]
    pub fn public_key(&self) -> [u8; 32] {
        let Self { public_key, .. } = self;
        *public_key
    }
}
impl std::fmt::Debug for SeedSigner {
    /// Shows only the hex-encoded public key; the secret key is deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let public_hex = hex_public(&self.public_key);
        let mut builder = f.debug_struct("SeedSigner");
        builder.field("public_key", &public_hex);
        builder.finish_non_exhaustive()
    }
}
/// Lower-case hex encoding (64 characters) of a 32-byte public key, used by the `Debug`
/// output of [`SeedSigner`].
fn hex_public(bytes: &[u8; 32]) -> String {
    use std::fmt::Write as _;

    let mut s = String::with_capacity(bytes.len() * 2);
    for b in bytes {
        // Format straight into the preallocated buffer instead of allocating a temporary
        // `String` per byte with `format!`.
        write!(s, "{b:02x}").expect("writing to a String cannot fail");
    }
    s
}
impl crate::client::Signer for SeedSigner {
    /// Returns the raw 32-byte Ed25519 verifying key as an owned vector.
    fn signer_pubkey(&self) -> Vec<u8> {
        Vec::from(self.public_key)
    }

    /// Signs `sig_structure_bytes` with the held Ed25519 secret key and returns the 64-byte
    /// signature. This implementation always returns `Ok`.
    fn sign(&self, sig_structure_bytes: &[u8]) -> Result<Vec<u8>, crate::client::SignerError> {
        use ed25519_dalek::{Signer as _, SigningKey};

        let key = SigningKey::from_bytes(&self.secret_key);
        let signature = key.sign(sig_structure_bytes);
        Ok(Vec::from(signature.to_bytes()))
    }
}
/// Builds a [`SeedSigner`] from the Ed25519 key pair derived from `seed` by
/// [`derive_ed25519_keypair`].
///
/// # Errors
/// [`SeedDeriveError::InvalidSeedLength`] if `seed` is not exactly [`SEED_LENGTH`] bytes.
pub fn signer_from_seed(seed: &[u8]) -> Result<SeedSigner, SeedDeriveError> {
    derive_ed25519_keypair(seed).map(|pair| SeedSigner {
        secret_key: pair.secret_key,
        public_key: pair.public_key,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rejects_wrong_seed_length() {
        assert_eq!(
            derive_ed25519_keypair(&[0u8; 31]).err(),
            Some(SeedDeriveError::InvalidSeedLength(31)),
        );
        assert_eq!(
            derive_x25519_keypair(&[0u8; 33]).err(),
            Some(SeedDeriveError::InvalidSeedLength(33)),
        );
        assert_eq!(
            derive_mlkem768x25519_keypair(&[]).err(),
            Some(SeedDeriveError::InvalidSeedLength(0)),
        );
    }

    #[test]
    fn info_labels_have_their_protocol_lengths() {
        assert_eq!(INFO_ED25519.len(), 22);
        assert_eq!(INFO_X25519.len(), 21);
        assert_eq!(INFO_MLKEM768X25519.len(), 29);
    }

    #[test]
    fn info_labels_domain_separate_the_derived_secrets() {
        let seed = [2u8; 32];
        let ed = derive_ed25519_keypair(&seed).unwrap();
        let x = derive_x25519_keypair(&seed).unwrap();
        let xw = derive_mlkem768x25519_keypair(&seed).unwrap();
        assert_ne!(ed.secret_key, x.secret_key);
        assert_ne!(ed.secret_key, xw.secret_seed);
        assert_ne!(x.secret_key, xw.secret_seed);
        assert_ne!(ed.secret_key, seed);
    }

    #[test]
    fn x25519_secret_is_stored_unclamped() {
        let pair = derive_x25519_keypair(&[7u8; 32]).unwrap();
        let raw = derive_key_material(&[7u8; 32], INFO_X25519);
        assert_eq!(pair.secret_key, raw);
        assert_eq!(pair.public_key, x25519_public_key(&raw));
    }

    // "Root seed" here means the X-Wing seed (the HKDF output), not the caller's master seed.
    #[test]
    fn xwing_secret_is_the_root_seed() {
        let xwing_seed = derive_key_material(&[3u8; 32], INFO_MLKEM768X25519);
        let pair = derive_mlkem768x25519_keypair(&[3u8; 32]).unwrap();
        assert_eq!(pair.secret_seed, xwing_seed);
        assert_eq!(pair.public_key.len(), MLKEM768X25519_PUBLIC_KEY_LENGTH);
        assert_eq!(pair.public_key, xwing_keygen(&xwing_seed));
    }

    #[test]
    fn xwing_public_key_ends_with_x25519_public_of_expanded_tail() {
        let xwing_seed = [5u8; 32];
        let expanded = expand_xwing_seed(&xwing_seed);
        let mut x_scalar = [0u8; 32];
        x_scalar.copy_from_slice(&expanded[64..]);
        let pk = xwing_keygen(&xwing_seed);
        assert_eq!(&pk[MLKEM_EK_LENGTH..], &x25519_public_key(&x_scalar)[..]);
        assert_eq!(pk, xwing_keygen(&xwing_seed));
    }

    #[test]
    fn derivation_is_deterministic() {
        let a = derive_ed25519_keypair(&[1u8; 32]).unwrap();
        let b = derive_ed25519_keypair(&[1u8; 32]).unwrap();
        assert_eq!(a.secret_key, b.secret_key);
        assert_eq!(a.public_key, b.public_key);
    }

    #[test]
    fn seed_signer_pubkey_matches_derivation_and_signs() {
        use crate::client::Signer;
        let seed = [9u8; 32];
        let signer = signer_from_seed(&seed).unwrap();
        let derived = derive_ed25519_keypair(&seed).unwrap();
        assert_eq!(signer.signer_pubkey(), derived.public_key.to_vec());
        let sig = signer.sign(b"cip-309 sig structure").unwrap();
        assert_eq!(sig.len(), 64);
        let sig2 = signer.sign(b"cip-309 sig structure").unwrap();
        assert_eq!(sig, sig2);
    }

    #[test]
    fn seed_signer_signature_verifies_under_its_public_key() {
        use crate::client::Signer;
        use ed25519_dalek::{Signature, VerifyingKey};
        let signer = signer_from_seed(&[11u8; 32]).unwrap();
        let msg = b"cip-309 sig structure";
        let sig_bytes: [u8; 64] = signer.sign(msg).unwrap().as_slice().try_into().unwrap();
        let verifying = VerifyingKey::from_bytes(&signer.public_key()).unwrap();
        let signature = Signature::from_bytes(&sig_bytes);
        assert!(verifying.verify_strict(msg, &signature).is_ok());
        assert!(verifying.verify_strict(b"tampered", &signature).is_err());
    }

    #[test]
    fn seed_signer_debug_shows_public_key_only() {
        let signer = signer_from_seed(&[4u8; 32]).unwrap();
        let rendered = format!("{signer:?}");
        assert!(rendered.contains(&hex_public(&signer.public_key())));
        assert!(!rendered.contains(&hex_public(&signer.secret_key)));
    }

    #[test]
    fn seed_signer_rejects_wrong_seed_length() {
        assert_eq!(
            signer_from_seed(&[0u8; 31]).err(),
            Some(SeedDeriveError::InvalidSeedLength(31))
        );
    }
}