#![allow(clippy::upper_case_acronyms)]
#[cfg(feature = "x509")]
mod cert;
#[cfg(feature = "x509")]
pub use cert::*;
pub mod xwing;
use crate::pem;
use base64::Engine;
use base64::engine::general_purpose::STANDARD as BASE64;
use hpke::rand_core::SeedableRng;
use hpke::{Deserializable, HpkeError, Kem, Serializable};
use pkcs8::PrivateKeyInfo;
use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
use sha2::Digest;
use spki::der::asn1::BitStringRef;
use spki::der::{AnyRef, Decode, Encode};
use spki::{AlgorithmIdentifier, ObjectIdentifier, SubjectPublicKeyInfo};
use std::error::Error;
/// Algorithm object identifier for X-Wing keys.
///
/// Written into, and required when parsing, both the PKCS#8 (`SecretKey`) and SPKI
/// (`PublicKey`) DER encodings.
pub const OID: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.3.6.1.4.1.62253.25722");

// HPKE cipher suite used by `seal`/`open`: X-Wing KEM, HKDF-SHA256 KDF, ChaCha20-Poly1305 AEAD.
type KEM = xwing::Kem;
type AEAD = hpke::aead::ChaCha20Poly1305;
type KDF = hpke::kdf::HkdfSha256;

/// Length in bytes of a raw X-Wing private key.
pub const SECRET_KEY_SIZE: usize = 32;
/// Length in bytes of a raw X-Wing public key:
/// ML-KEM-768 encapsulation key (1184) followed by an X25519 public key (32).
pub const PUBLIC_KEY_SIZE: usize = 1184 + 32;
/// Length in bytes of an X-Wing encapsulated key (HPKE `enc`):
/// ML-KEM-768 ciphertext (1088) followed by an X25519 ephemeral public key (32).
pub const ENCAP_KEY_SIZE: usize = 1088 + 32;
/// Length in bytes of a public-key fingerprint (a SHA-256 digest).
pub const FINGERPRINT_SIZE: usize = 32;
/// An X-Wing private (decapsulation) key.
///
/// Round-trips through raw bytes (`to_bytes`/`from_bytes`), PKCS#8 DER (`to_der`/`from_der`)
/// and PEM (`to_pem`/`from_pem`), and decrypts messages produced by `PublicKey::seal`.
///
/// NOTE(review): there is no `Debug` derive, presumably so key material is never printed.
/// The derived `PartialEq` delegates to the inner KEM type and may not be constant-time;
/// confirm before comparing keys against attacker-influenced values.
#[derive(Eq, PartialEq, Clone)]
pub struct SecretKey {
    // Private key object of the X-Wing KEM from the `xwing` submodule.
    inner: <KEM as Kem>::PrivateKey,
}
impl SecretKey {
    /// Generates a fresh random X-Wing private key using the thread-local RNG.
    pub fn generate() -> SecretKey {
        let mut rng = rand::rng();
        let (key, _) = KEM::gen_keypair(&mut rng);
        Self { inner: key }
    }

    /// Builds a private key from its raw `SECRET_KEY_SIZE`-byte encoding.
    ///
    /// # Panics
    ///
    /// Panics if the underlying KEM rejects the bytes. The fixed-size parameter implies the
    /// author expects this never to happen; NOTE(review): confirm against `xwing::Kem`.
    pub fn from_bytes(bin: &[u8; SECRET_KEY_SIZE]) -> Self {
        let inner = <KEM as Kem>::PrivateKey::from_bytes(bin)
            .expect("X-Wing KEM rejected a SECRET_KEY_SIZE-byte private key");
        Self { inner }
    }

    /// Parses a DER-encoded PKCS#8 `PrivateKeyInfo` holding an X-Wing private key.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `der` is not a valid `PrivateKeyInfo`;
    /// - `der` has bytes after the structure;
    /// - the algorithm OID is not [`OID`];
    /// - the private key field is not exactly `SECRET_KEY_SIZE` bytes long.
    pub fn from_der(der: &[u8]) -> Result<Self, Box<dyn Error>> {
        let info = PrivateKeyInfo::from_der(der)?;
        // Require that the parsed structure accounts for every input byte.
        if info.encoded_len()?.try_into() != Ok(der.len()) {
            return Err("trailing data in private key".into());
        }
        if info.algorithm.oid != OID {
            return Err("not an X-Wing private key".into());
        }
        // Map the opaque `TryFromSliceError` to a message that states the actual length.
        let bytes: [u8; SECRET_KEY_SIZE] = info.private_key.try_into().map_err(|_| {
            format!(
                "invalid X-Wing private key length {}, expected {}",
                info.private_key.len(),
                SECRET_KEY_SIZE
            )
        })?;
        Ok(SecretKey::from_bytes(&bytes))
    }

    /// Parses a PEM document with the `PRIVATE KEY` label (PKCS#8).
    ///
    /// # Errors
    ///
    /// Returns an error if the PEM cannot be decoded, the label is not `PRIVATE KEY`, or
    /// [`SecretKey::from_der`] rejects the contents.
    pub fn from_pem(pem_str: &str) -> Result<Self, Box<dyn Error>> {
        let (kind, data) = pem::decode(pem_str.as_bytes())?;
        if kind != "PRIVATE KEY" {
            return Err(format!("invalid PEM tag {}", kind).into());
        }
        Self::from_der(&data)
    }

    /// Returns the raw `SECRET_KEY_SIZE`-byte encoding of this key.
    pub fn to_bytes(&self) -> [u8; SECRET_KEY_SIZE] {
        self.inner.to_bytes().into()
    }

    /// Encodes this key as a DER PKCS#8 `PrivateKeyInfo`.
    ///
    /// The algorithm identifier is [`OID`] with absent parameters, and no public key is
    /// embedded.
    pub fn to_der(&self) -> Vec<u8> {
        let bytes = self.inner.to_bytes();
        let alg = pkcs8::AlgorithmIdentifierRef {
            oid: OID,
            parameters: None::<AnyRef>,
        };
        let info = PrivateKeyInfo {
            algorithm: alg,
            private_key: &bytes,
            public_key: None,
        };
        info.to_der()
            .expect("failed to DER-encode a fixed-size X-Wing private key")
    }

    /// Encodes this key as PEM with the `PRIVATE KEY` label.
    pub fn to_pem(&self) -> String {
        pem::encode("PRIVATE KEY", &self.to_der())
    }

    /// Derives the public key that corresponds to this private key.
    pub fn public_key(&self) -> PublicKey {
        PublicKey {
            inner: KEM::sk_to_pk(&self.inner),
        }
    }

    /// Returns the fingerprint of the corresponding public key.
    pub fn fingerprint(&self) -> Fingerprint {
        self.public_key().fingerprint()
    }

    /// Decrypts a message sealed to this key with `PublicKey::seal`.
    ///
    /// Runs the HPKE base-mode receiver: X-Wing KEM, HKDF-SHA256, ChaCha20-Poly1305.
    ///
    /// * `session_key` is the encapsulated key returned by `seal`.
    /// * `msg_to_open` is the ciphertext.
    /// * `msg_to_auth` is the associated data. It must match the value given to `seal`.
    /// * `domain` is the HPKE `info` string. It must also match the value given to `seal`.
    ///
    /// # Errors
    ///
    /// Returns an [`HpkeError`] if the encapsulated key cannot be parsed, the receiver
    /// context cannot be set up, or authentication of the ciphertext fails.
    pub fn open(
        &self,
        session_key: &[u8; ENCAP_KEY_SIZE],
        msg_to_open: &[u8],
        msg_to_auth: &[u8],
        domain: &[u8],
    ) -> Result<Vec<u8>, HpkeError> {
        let session = <KEM as Kem>::EncappedKey::from_bytes(session_key)?;
        let mut ctx = hpke::setup_receiver::<AEAD, KDF, KEM>(
            &hpke::OpModeR::Base,
            &self.inner,
            &session,
            domain,
        )?;
        ctx.open(msg_to_open, msg_to_auth)
    }
}
/// An X-Wing public (encapsulation) key.
///
/// Obtained from `SecretKey::public_key` or parsed from raw bytes, SPKI DER or PEM.
/// Messages sealed with `seal` can only be opened by the matching `SecretKey`.
#[derive(Eq, PartialEq, Clone, Debug)]
pub struct PublicKey {
    // Public key object of the X-Wing KEM from the `xwing` submodule.
    inner: <KEM as Kem>::PublicKey,
}
impl PublicKey {
pub fn from_bytes(bin: &[u8; PUBLIC_KEY_SIZE]) -> Result<Self, Box<dyn Error>> {
validate_mlkem768_encapsulation_key(&bin[..1184])?;
let inner = <KEM as Kem>::PublicKey::from_bytes(bin)?;
Ok(Self { inner })
}
pub fn from_der(der: &[u8]) -> Result<Self, Box<dyn Error>> {
let info: SubjectPublicKeyInfo<AlgorithmIdentifier<AnyRef>, BitStringRef> =
SubjectPublicKeyInfo::from_der(der)?;
if info.encoded_len()?.try_into() != Ok(der.len()) {
return Err("trailing data in public key".into());
}
if info.algorithm.oid != OID {
return Err("not an X-Wing public key".into());
}
let key = info.subject_public_key.as_bytes().unwrap();
let bytes: [u8; 1216] = key.try_into()?;
PublicKey::from_bytes(&bytes)
}
pub fn from_pem(pem_str: &str) -> Result<Self, Box<dyn Error>> {
let (kind, data) = pem::decode(pem_str.as_bytes())?;
if kind != "PUBLIC KEY" {
return Err(format!("invalid PEM tag {}", kind).into());
}
Self::from_der(&data)
}
pub fn to_bytes(&self) -> [u8; PUBLIC_KEY_SIZE] {
let mut result = [0u8; 1216];
result.copy_from_slice(&self.inner.to_bytes());
result
}
pub fn to_der(&self) -> Vec<u8> {
let bytes = self.inner.to_bytes();
let alg = AlgorithmIdentifier::<AnyRef> {
oid: OID,
parameters: None::<AnyRef>,
};
let info = SubjectPublicKeyInfo::<AnyRef, BitStringRef> {
algorithm: alg,
subject_public_key: BitStringRef::from_bytes(&bytes).unwrap(),
};
info.to_der().unwrap()
}
pub fn to_pem(&self) -> String {
pem::encode("PUBLIC KEY", &self.to_der())
}
pub fn fingerprint(&self) -> Fingerprint {
let mut hasher = sha2::Sha256::new();
hasher.update(self.to_bytes());
Fingerprint(hasher.finalize().into())
}
pub fn seal(
&self,
msg_to_seal: &[u8],
msg_to_auth: &[u8],
domain: &[u8],
) -> Result<([u8; ENCAP_KEY_SIZE], Vec<u8>), HpkeError> {
let mut seed = [0u8; 32];
getrandom::fill(&mut seed).expect("Failed to get random seed");
let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed);
let (key, mut ctx) = hpke::setup_sender::<AEAD, KDF, KEM, _>(
&hpke::OpModeS::Base,
&self.inner,
domain,
&mut rng,
)?;
let enc = ctx.seal(msg_to_seal, msg_to_auth)?;
let mut encap_key = [0u8; 1120];
encap_key.copy_from_slice(&key.to_bytes());
Ok((encap_key, enc))
}
}
/// Serializes a public key as a base64 string of its raw bytes.
///
/// Uses the standard alphabet with padding.
impl Serialize for PublicKey {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = BASE64.encode(self.to_bytes());
        serializer.serialize_str(&encoded)
    }
}
impl<'de> Deserialize<'de> for PublicKey {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
let bytes = BASE64.decode(&s).map_err(de::Error::custom)?;
let arr: [u8; PUBLIC_KEY_SIZE] = bytes
.try_into()
.map_err(|_| de::Error::custom("invalid public key length"))?;
PublicKey::from_bytes(&arr).map_err(de::Error::custom)
}
}
/// CBOR form of a public key: its raw `PUBLIC_KEY_SIZE` bytes.
///
/// The bytes are encoded by the fixed-size byte array's `Encode` implementation.
#[cfg(feature = "cbor")]
impl crate::cbor::Encode for PublicKey {
    fn encode_cbor(&self) -> Vec<u8> {
        let raw = self.to_bytes();
        raw.encode_cbor()
    }
}
/// Decodes a public key from CBOR raw key bytes.
///
/// The bytes are validated by `PublicKey::from_bytes`. Validation failures become
/// `DecodeFailed` with the underlying error text.
#[cfg(feature = "cbor")]
impl crate::cbor::Decode for PublicKey {
    /// Decodes from a complete CBOR buffer.
    fn decode_cbor(data: &[u8]) -> Result<Self, crate::cbor::Error> {
        let raw = <[u8; PUBLIC_KEY_SIZE]>::decode_cbor(data)?;
        match Self::from_bytes(&raw) {
            Ok(key) => Ok(key),
            Err(err) => Err(crate::cbor::Error::DecodeFailed(err.to_string())),
        }
    }

    /// Decodes from a positioned `Decoder`.
    ///
    /// NOTE(review): judging by the name, this presumably skips the trailing-data check;
    /// confirm in `crate::cbor`.
    fn decode_cbor_notrail(
        decoder: &mut crate::cbor::Decoder<'_>,
    ) -> Result<Self, crate::cbor::Error> {
        let raw = decoder.decode_bytes_fixed::<PUBLIC_KEY_SIZE>()?;
        match Self::from_bytes(&raw) {
            Ok(key) => Ok(key),
            Err(err) => Err(crate::cbor::Error::DecodeFailed(err.to_string())),
        }
    }
}
/// A `FINGERPRINT_SIZE`-byte public-key fingerprint.
///
/// `PublicKey::fingerprint` computes it as the SHA-256 digest of the raw public key bytes.
#[derive(Eq, PartialEq, Copy, Clone, Debug)]
pub struct Fingerprint([u8; FINGERPRINT_SIZE]);
impl Fingerprint {
pub fn from_bytes(bytes: &[u8; FINGERPRINT_SIZE]) -> Self {
Self(*bytes)
}
pub fn to_bytes(&self) -> [u8; FINGERPRINT_SIZE] {
self.0
}
}
/// Serializes a fingerprint as a base64 string of its digest bytes.
///
/// Uses the standard alphabet with padding.
impl Serialize for Fingerprint {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let encoded = BASE64.encode(self.to_bytes());
        serializer.serialize_str(&encoded)
    }
}
impl<'de> Deserialize<'de> for Fingerprint {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
let bytes = BASE64.decode(&s).map_err(de::Error::custom)?;
let arr: [u8; FINGERPRINT_SIZE] = bytes
.try_into()
.map_err(|_| de::Error::custom("invalid fingerprint length"))?;
Ok(Fingerprint::from_bytes(&arr))
}
}
/// CBOR form of a fingerprint: its raw `FINGERPRINT_SIZE` digest bytes.
///
/// The bytes are encoded by the fixed-size byte array's `Encode` implementation.
#[cfg(feature = "cbor")]
impl crate::cbor::Encode for Fingerprint {
    fn encode_cbor(&self) -> Vec<u8> {
        let digest = self.to_bytes();
        digest.encode_cbor()
    }
}
/// Decodes a fingerprint from CBOR raw digest bytes.
///
/// No further validation is applied beyond the fixed length.
#[cfg(feature = "cbor")]
impl crate::cbor::Decode for Fingerprint {
    /// Decodes from a complete CBOR buffer.
    fn decode_cbor(data: &[u8]) -> Result<Self, crate::cbor::Error> {
        let digest = <[u8; FINGERPRINT_SIZE]>::decode_cbor(data)?;
        Ok(Fingerprint::from_bytes(&digest))
    }

    /// Decodes from a positioned `Decoder`.
    ///
    /// NOTE(review): judging by the name, this presumably skips the trailing-data check;
    /// confirm in `crate::cbor`.
    fn decode_cbor_notrail(
        decoder: &mut crate::cbor::Decoder<'_>,
    ) -> Result<Self, crate::cbor::Error> {
        let digest = decoder.decode_bytes_fixed::<FINGERPRINT_SIZE>()?;
        Ok(Fingerprint::from_bytes(&digest))
    }
}
/// Checks that the polynomial part of an ML-KEM-768 encapsulation key is canonically encoded.
///
/// This is the FIPS 203 "modulus check". The first 1152 bytes (three polynomials of 384
/// bytes each) are decoded as 12-bit coefficients (ByteDecode_12), and each coefficient
/// must be less than q = 3329. Any bytes after that prefix are not inspected.
///
/// # Errors
///
/// Returns an error if `key` is shorter than 1152 bytes, or if any decoded coefficient is
/// `>= 3329`. The first offending coefficient is the one reported.
fn validate_mlkem768_encapsulation_key(key: &[u8]) -> Result<(), Box<dyn Error>> {
    // ML-KEM modulus.
    const Q: u16 = 3329;
    // Encoded size of the k = 3 polynomials, 256 coefficients × 12 bits each.
    const POLY_BYTES: usize = 1152;

    // Report short input as an error instead of panicking on the slice.
    let coeff_bytes = key.get(..POLY_BYTES).ok_or_else(|| {
        format!(
            "ML-KEM-768 encapsulation key too short: {} < {} bytes",
            key.len(),
            POLY_BYTES
        )
    })?;

    // Each 3-byte group packs two little-endian 12-bit coefficients.
    for chunk in coeff_bytes.chunks_exact(3) {
        let (b0, b1, b2) = (
            u16::from(chunk[0]),
            u16::from(chunk[1]),
            u16::from(chunk[2]),
        );
        let coeff1 = b0 | ((b1 & 0x0F) << 8);
        let coeff2 = (b1 >> 4) | (b2 << 4);
        for coeff in [coeff1, coeff2] {
            if coeff >= Q {
                return Err(format!("invalid ML-KEM coefficient: {} >= {}", coeff, Q).into());
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_secretkey_bytes_roundtrip() {
        let key = SecretKey::generate();
        let bytes = key.to_bytes();
        let parsed = SecretKey::from_bytes(&bytes);
        assert_eq!(key.to_bytes(), parsed.to_bytes());
    }

    #[test]
    fn test_publickey_bytes_roundtrip() {
        let key = SecretKey::generate().public_key();
        let bytes = key.to_bytes();
        let parsed = PublicKey::from_bytes(&bytes).unwrap();
        assert_eq!(key.to_bytes(), parsed.to_bytes());
    }

    #[test]
    fn test_secretkey_der_roundtrip() {
        let key = SecretKey::generate();
        let der = key.to_der();
        let parsed = SecretKey::from_der(&der).unwrap();
        assert_eq!(key.to_bytes(), parsed.to_bytes());
    }

    #[test]
    fn test_secretkey_pem_roundtrip() {
        let key = SecretKey::generate();
        let pem = key.to_pem();
        let parsed = SecretKey::from_pem(&pem).unwrap();
        assert_eq!(key.to_bytes(), parsed.to_bytes());
    }

    #[test]
    fn test_publickey_der_roundtrip() {
        let key = SecretKey::generate().public_key();
        let der = key.to_der();
        let parsed = PublicKey::from_der(&der).unwrap();
        assert_eq!(key.to_bytes(), parsed.to_bytes());
    }

    #[test]
    fn test_publickey_pem_roundtrip() {
        let key = SecretKey::generate().public_key();
        let pem = key.to_pem();
        let parsed = PublicKey::from_pem(&pem).unwrap();
        assert_eq!(key.to_bytes(), parsed.to_bytes());
    }

    /// A PEM block with the other key type's label must be rejected.
    #[test]
    fn test_pem_tag_mismatch_rejected() {
        let secret = SecretKey::generate();
        assert!(PublicKey::from_pem(&secret.to_pem()).is_err());
        assert!(SecretKey::from_pem(&secret.public_key().to_pem()).is_err());
    }

    /// A public key whose ML-KEM part holds a coefficient >= q must be rejected.
    #[test]
    fn test_publickey_rejects_noncanonical_mlkem_coefficient() {
        let mut bytes = SecretKey::generate().public_key().to_bytes();
        // Force the first 12-bit coefficient to 0xFFF (4095), which exceeds q = 3329.
        bytes[0] = 0xFF;
        bytes[1] |= 0x0F;
        assert!(PublicKey::from_bytes(&bytes).is_err());
    }

    #[test]
    fn test_fingerprint_consistency() {
        let secret = SecretKey::generate();
        let fingerprint = secret.fingerprint();
        assert_eq!(fingerprint, secret.public_key().fingerprint());
        assert_eq!(Fingerprint::from_bytes(&fingerprint.to_bytes()), fingerprint);
        assert_ne!(fingerprint, SecretKey::generate().fingerprint());
    }

    #[test]
    fn test_seal_open() {
        let secret = SecretKey::generate();
        let public = secret.public_key();
        struct TestCase<'a> {
            seal_msg: &'a [u8],
            auth_msg: &'a [u8],
        }
        let tests = [
            TestCase {
                seal_msg: &[],
                auth_msg: b"message to authenticate",
            },
            TestCase {
                seal_msg: b"message to encrypt",
                auth_msg: &[],
            },
            TestCase {
                seal_msg: b"message to encrypt",
                auth_msg: b"message to authenticate",
            },
        ];
        for tt in &tests {
            let (sess_key, seal_msg) = public
                .seal(tt.seal_msg, tt.auth_msg, b"test")
                .unwrap_or_else(|e| panic!("failed to seal message: {}", e));
            let cleartext = secret
                .open(&sess_key, &seal_msg, tt.auth_msg, b"test")
                .unwrap_or_else(|e| panic!("failed to open message: {}", e));
            assert_eq!(cleartext, tt.seal_msg, "unexpected cleartext");
        }
    }

    /// Opening must fail whenever any input differs from what was sealed.
    #[test]
    fn test_open_rejects_mismatched_inputs() {
        let secret = SecretKey::generate();
        let public = secret.public_key();
        let (sess_key, sealed) = public
            .seal(b"message to encrypt", b"auth", b"test")
            .unwrap_or_else(|e| panic!("failed to seal message: {}", e));

        assert!(
            secret.open(&sess_key, &sealed, b"auth", b"other").is_err(),
            "wrong domain accepted"
        );
        assert!(
            secret.open(&sess_key, &sealed, b"other", b"test").is_err(),
            "wrong auth data accepted"
        );

        let mut tampered = sealed.clone();
        tampered[0] ^= 1;
        assert!(
            secret.open(&sess_key, &tampered, b"auth", b"test").is_err(),
            "tampered ciphertext accepted"
        );

        let other = SecretKey::generate();
        assert!(
            other.open(&sess_key, &sealed, b"auth", b"test").is_err(),
            "wrong recipient accepted"
        );
    }
}