1#![allow(clippy::upper_case_acronyms)]
14
15pub mod cert;
16pub mod xwing;
17
18use crate::pem;
19use base64::Engine;
20use base64::engine::general_purpose::STANDARD as BASE64;
21use hpke::rand_core::SeedableRng;
22use hpke::{Deserializable, HpkeError, Kem, Serializable};
23use pkcs8::PrivateKeyInfo;
24use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
25use sha2::Digest;
26use spki::der::asn1::BitStringRef;
27use spki::der::{AnyRef, Decode, Encode};
28use spki::{AlgorithmIdentifier, ObjectIdentifier, SubjectPublicKeyInfo};
29use std::error::Error;
30
/// Key encapsulation mechanism used for every HPKE operation in this module:
/// the KEM exported by the local `xwing` submodule.
type KEM = xwing::Kem;
/// Authenticated cipher used for HPKE payloads (ChaCha20-Poly1305).
type AEAD = hpke::aead::ChaCha20Poly1305;
/// Key derivation function used by HPKE (HKDF with SHA-256).
type KDF = hpke::kdf::HkdfSha256;
42
/// Size in bytes of a serialized [`SecretKey`].
pub const SECRET_KEY_SIZE: usize = 32;

/// Size in bytes of a serialized [`PublicKey`].
///
/// The first 1184 bytes are validated as an ML-KEM-768 encapsulation key when
/// a public key is parsed; the remaining 32 bytes are presumably the X25519
/// half of the X-Wing key (NOTE(review): confirm against the `xwing` module).
pub const PUBLIC_KEY_SIZE: usize = 1216;

/// Size in bytes of the encapsulated session key returned by
/// [`PublicKey::seal`] and consumed by [`SecretKey::open`].
pub const ENCAP_KEY_SIZE: usize = 1120;

/// Size in bytes of a [`Fingerprint`] (the length of a SHA-256 digest).
pub const FINGERPRINT_SIZE: usize = 32;
54
/// Private key for the configured KEM, able to open messages that were sealed
/// to the matching [`PublicKey`].
///
/// `Debug` is not derived for this type (presumably so that key material
/// cannot end up in logs by accident).
#[derive(Clone, PartialEq, Eq)]
pub struct SecretKey {
    /// Private key in the KEM's own representation.
    inner: <KEM as Kem>::PrivateKey,
}
60
61impl SecretKey {
62 pub fn generate() -> SecretKey {
64 let mut rng = rand::rng();
65
66 let (key, _) = KEM::gen_keypair(&mut rng);
67 Self { inner: key }
68 }
69
70 pub fn from_bytes(bin: &[u8; SECRET_KEY_SIZE]) -> Self {
72 let inner = <KEM as Kem>::PrivateKey::from_bytes(bin).unwrap();
73 Self { inner }
74 }
75
76 pub fn from_der(der: &[u8]) -> Result<Self, Box<dyn Error>> {
78 let info = PrivateKeyInfo::from_der(der)?;
80
81 if info.encoded_len()?.try_into() != Ok(der.len()) {
83 return Err("trailing data in private key".into());
84 }
85 if info.algorithm.oid.to_string() != "1.3.6.1.4.1.62253.25722" {
87 return Err("not an X-Wing private key".into());
88 }
89 let bytes: [u8; 32] = info.private_key.try_into()?;
90 Ok(SecretKey::from_bytes(&bytes))
91 }
92
93 pub fn from_pem(pem_str: &str) -> Result<Self, Box<dyn Error>> {
95 let (kind, data) = pem::decode(pem_str.as_bytes())?;
97 if kind != "PRIVATE KEY" {
98 return Err(format!("invalid PEM tag {}", kind).into());
99 }
100 Self::from_der(&data)
102 }
103
104 pub fn to_bytes(&self) -> [u8; SECRET_KEY_SIZE] {
106 self.inner.to_bytes().into()
107 }
108
109 pub fn to_der(&self) -> Vec<u8> {
111 let bytes = self.inner.to_bytes();
112
113 let alg = pkcs8::AlgorithmIdentifierRef {
115 oid: ObjectIdentifier::new_unwrap("1.3.6.1.4.1.62253.25722"),
116 parameters: None::<AnyRef>,
117 };
118 let info = PrivateKeyInfo {
120 algorithm: alg,
121 private_key: &bytes,
122 public_key: None,
123 };
124 info.to_der().unwrap()
125 }
126
127 pub fn to_pem(&self) -> String {
129 pem::encode("PRIVATE KEY", &self.to_der())
130 }
131
132 pub fn public_key(&self) -> PublicKey {
134 PublicKey {
135 inner: KEM::sk_to_pk(&self.inner),
136 }
137 }
138
139 pub fn fingerprint(&self) -> Fingerprint {
142 self.public_key().fingerprint()
143 }
144
145 pub fn open(
153 &self,
154 session_key: &[u8; ENCAP_KEY_SIZE],
155 msg_to_open: &[u8],
156 msg_to_auth: &[u8],
157 domain: &[u8],
158 ) -> Result<Vec<u8>, HpkeError> {
159 let session = <KEM as Kem>::EncappedKey::from_bytes(session_key)?;
161
162 let mut ctx = hpke::setup_receiver::<AEAD, KDF, KEM>(
164 &hpke::OpModeR::Base,
165 &self.inner,
166 &session,
167 domain,
168 )?;
169 ctx.open(msg_to_open, msg_to_auth)
171 }
172}
173
/// Public key for the configured KEM; messages sealed to it can only be
/// opened with the matching [`SecretKey`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PublicKey {
    /// Public key in the KEM's own representation.
    inner: <KEM as Kem>::PublicKey,
}
179
180impl PublicKey {
181 pub fn from_bytes(bin: &[u8; PUBLIC_KEY_SIZE]) -> Result<Self, Box<dyn Error>> {
186 validate_mlkem768_encapsulation_key(&bin[..1184])?;
190
191 let inner = <KEM as Kem>::PublicKey::from_bytes(bin)?;
192 Ok(Self { inner })
193 }
194
195 pub fn from_der(der: &[u8]) -> Result<Self, Box<dyn Error>> {
197 let info: SubjectPublicKeyInfo<AlgorithmIdentifier<AnyRef>, BitStringRef> =
199 SubjectPublicKeyInfo::from_der(der)?;
200
201 if info.encoded_len()?.try_into() != Ok(der.len()) {
203 return Err("trailing data in public key".into());
204 }
205 if info.algorithm.oid.to_string() != "1.3.6.1.4.1.62253.25722" {
207 return Err("not an X-Wing public key".into());
208 }
209 let key = info.subject_public_key.as_bytes().unwrap();
210
211 let bytes: [u8; 1216] = key.try_into()?;
213 PublicKey::from_bytes(&bytes)
214 }
215
216 pub fn from_pem(pem_str: &str) -> Result<Self, Box<dyn Error>> {
218 let (kind, data) = pem::decode(pem_str.as_bytes())?;
220 if kind != "PUBLIC KEY" {
221 return Err(format!("invalid PEM tag {}", kind).into());
222 }
223 Self::from_der(&data)
225 }
226
227 pub fn to_bytes(&self) -> [u8; PUBLIC_KEY_SIZE] {
229 let mut result = [0u8; 1216];
230 result.copy_from_slice(&self.inner.to_bytes());
231 result
232 }
233
234 pub fn to_der(&self) -> Vec<u8> {
236 let bytes = self.inner.to_bytes();
237
238 let alg = AlgorithmIdentifier::<AnyRef> {
240 oid: ObjectIdentifier::new_unwrap("1.3.6.1.4.1.62253.25722"),
241 parameters: None::<AnyRef>,
242 };
243 let info = SubjectPublicKeyInfo::<AnyRef, BitStringRef> {
245 algorithm: alg,
246 subject_public_key: BitStringRef::from_bytes(&bytes).unwrap(),
247 };
248 info.to_der().unwrap()
249 }
250
251 pub fn to_pem(&self) -> String {
253 pem::encode("PUBLIC KEY", &self.to_der())
254 }
255
256 pub fn fingerprint(&self) -> Fingerprint {
259 let mut hasher = sha2::Sha256::new();
260 hasher.update(self.to_bytes());
261 Fingerprint(hasher.finalize().into())
262 }
263
264 pub fn seal(
276 &self,
277 msg_to_seal: &[u8],
278 msg_to_auth: &[u8],
279 domain: &[u8],
280 ) -> Result<([u8; ENCAP_KEY_SIZE], Vec<u8>), HpkeError> {
281 let mut seed = [0u8; 32];
283 getrandom::fill(&mut seed).expect("Failed to get random seed");
284 let mut rng = rand_chacha::ChaCha20Rng::from_seed(seed);
285
286 let (key, mut ctx) = hpke::setup_sender::<AEAD, KDF, KEM, _>(
288 &hpke::OpModeS::Base,
289 &self.inner,
290 domain,
291 &mut rng,
292 )?;
293
294 let enc = ctx.seal(msg_to_seal, msg_to_auth)?;
296
297 let mut encap_key = [0u8; 1120];
298 encap_key.copy_from_slice(&key.to_bytes());
299 Ok((encap_key, enc))
300 }
301}
302
303impl Serialize for PublicKey {
304 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
305 serializer.serialize_str(&BASE64.encode(self.to_bytes()))
306 }
307}
308
309impl<'de> Deserialize<'de> for PublicKey {
310 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
311 let s = String::deserialize(deserializer)?;
312 let bytes = BASE64.decode(&s).map_err(de::Error::custom)?;
313 let arr: [u8; PUBLIC_KEY_SIZE] = bytes
314 .try_into()
315 .map_err(|_| de::Error::custom("invalid public key length"))?;
316 PublicKey::from_bytes(&arr).map_err(de::Error::custom)
317 }
318}
319
/// CBOR-encodes a public key as its fixed-size byte encoding.
#[cfg(feature = "cbor")]
impl crate::cbor::Encode for PublicKey {
    fn encode_cbor(&self) -> Vec<u8> {
        let raw = self.to_bytes();
        raw.encode_cbor()
    }
}
326
/// CBOR-decodes a public key from a fixed-size byte array, running the same
/// validation as [`PublicKey::from_bytes`].
#[cfg(feature = "cbor")]
impl crate::cbor::Decode for PublicKey {
    fn decode_cbor(data: &[u8]) -> Result<Self, crate::cbor::Error> {
        let raw = <[u8; PUBLIC_KEY_SIZE]>::decode_cbor(data)?;
        match Self::from_bytes(&raw) {
            Ok(key) => Ok(key),
            // Key validation failures surface as CBOR decode failures.
            Err(err) => Err(crate::cbor::Error::DecodeFailed(err.to_string())),
        }
    }

    fn decode_cbor_notrail(
        decoder: &mut crate::cbor::Decoder<'_>,
    ) -> Result<Self, crate::cbor::Error> {
        let raw = decoder.decode_bytes_fixed::<PUBLIC_KEY_SIZE>()?;
        match Self::from_bytes(&raw) {
            Ok(key) => Ok(key),
            // Key validation failures surface as CBOR decode failures.
            Err(err) => Err(crate::cbor::Error::DecodeFailed(err.to_string())),
        }
    }
}
341
/// Fixed-size identifier of a public key.
///
/// [`PublicKey::fingerprint`] produces it as the SHA-256 digest of the key's
/// byte encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Fingerprint([u8; FINGERPRINT_SIZE]);
345
346impl Fingerprint {
347 pub fn from_bytes(bytes: &[u8; FINGERPRINT_SIZE]) -> Self {
349 Self(*bytes)
350 }
351
352 pub fn to_bytes(&self) -> [u8; FINGERPRINT_SIZE] {
354 self.0
355 }
356}
357
358impl Serialize for Fingerprint {
359 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
360 serializer.serialize_str(&BASE64.encode(self.to_bytes()))
361 }
362}
363
364impl<'de> Deserialize<'de> for Fingerprint {
365 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
366 let s = String::deserialize(deserializer)?;
367 let bytes = BASE64.decode(&s).map_err(de::Error::custom)?;
368 let arr: [u8; FINGERPRINT_SIZE] = bytes
369 .try_into()
370 .map_err(|_| de::Error::custom("invalid fingerprint length"))?;
371 Ok(Fingerprint::from_bytes(&arr))
372 }
373}
374
/// CBOR-encodes a fingerprint as its fixed-size byte array.
#[cfg(feature = "cbor")]
impl crate::cbor::Encode for Fingerprint {
    fn encode_cbor(&self) -> Vec<u8> {
        let raw = self.to_bytes();
        raw.encode_cbor()
    }
}
381
/// CBOR-decodes a fingerprint from a fixed-size byte array.
#[cfg(feature = "cbor")]
impl crate::cbor::Decode for Fingerprint {
    fn decode_cbor(data: &[u8]) -> Result<Self, crate::cbor::Error> {
        let raw = <[u8; FINGERPRINT_SIZE]>::decode_cbor(data)?;
        let fingerprint = Self::from_bytes(&raw);
        Ok(fingerprint)
    }

    fn decode_cbor_notrail(
        decoder: &mut crate::cbor::Decoder<'_>,
    ) -> Result<Self, crate::cbor::Error> {
        let raw = decoder.decode_bytes_fixed::<FINGERPRINT_SIZE>()?;
        let fingerprint = Self::from_bytes(&raw);
        Ok(fingerprint)
    }
}
396
/// Checks that every packed coefficient at the front of an ML-KEM-768
/// encapsulation key is reduced modulo `Q = 3329`.
///
/// The first 1152 bytes of `key` are read as groups of three bytes, each
/// holding two 12-bit little-endian coefficients. Bytes beyond the first 1152
/// are not inspected.
///
/// # Errors
///
/// Returns an error if `key` is shorter than 1152 bytes, or if any
/// coefficient is greater than or equal to `Q`.
fn validate_mlkem768_encapsulation_key(key: &[u8]) -> Result<(), Box<dyn Error>> {
    const Q: u16 = 3329;
    // Number of leading bytes that hold packed coefficients.
    const COEFF_BYTES: usize = 1152;

    // Checked slicing: a short input is reported as an error, not a panic.
    let coeff_bytes = key.get(..COEFF_BYTES).ok_or_else(|| {
        format!(
            "ML-KEM encapsulation key too short: {} < {}",
            key.len(),
            COEFF_BYTES
        )
    })?;

    // COEFF_BYTES is a multiple of 3, so `chunks_exact` covers every byte.
    for chunk in coeff_bytes.chunks_exact(3) {
        let (b0, b1, b2) = (
            u16::from(chunk[0]),
            u16::from(chunk[1]),
            u16::from(chunk[2]),
        );
        // Low coefficient: b0 plus the low nibble of b1 as its top 4 bits.
        let coeff1 = b0 | ((b1 & 0x0F) << 8);
        // High coefficient: high nibble of b1 plus b2 as its top 8 bits.
        let coeff2 = (b1 >> 4) | (b2 << 4);

        for coeff in [coeff1, coeff2] {
            if coeff >= Q {
                return Err(format!("invalid ML-KEM coefficient: {} >= {}", coeff, Q).into());
            }
        }
    }
    Ok(())
}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// A secret key survives a round trip through its raw byte encoding.
    #[test]
    fn test_secretkey_bytes_roundtrip() {
        let original = SecretKey::generate();
        let restored = SecretKey::from_bytes(&original.to_bytes());
        assert_eq!(original.to_bytes(), restored.to_bytes());
    }

    /// A public key survives a round trip through its raw byte encoding.
    #[test]
    fn test_publickey_bytes_roundtrip() {
        let original = SecretKey::generate().public_key();
        let restored = PublicKey::from_bytes(&original.to_bytes()).unwrap();
        assert_eq!(original.to_bytes(), restored.to_bytes());
    }

    /// A secret key survives a round trip through DER.
    #[test]
    fn test_secretkey_der_roundtrip() {
        let original = SecretKey::generate();
        let encoded = original.to_der();
        let restored = SecretKey::from_der(&encoded).unwrap();
        assert_eq!(original.to_bytes(), restored.to_bytes());
    }

    /// A secret key survives a round trip through PEM.
    #[test]
    fn test_secretkey_pem_roundtrip() {
        let original = SecretKey::generate();
        let encoded = original.to_pem();
        let restored = SecretKey::from_pem(&encoded).unwrap();
        assert_eq!(original.to_bytes(), restored.to_bytes());
    }

    /// A public key survives a round trip through DER.
    #[test]
    fn test_publickey_der_roundtrip() {
        let original = SecretKey::generate().public_key();
        let encoded = original.to_der();
        let restored = PublicKey::from_der(&encoded).unwrap();
        assert_eq!(original.to_bytes(), restored.to_bytes());
    }

    /// A public key survives a round trip through PEM.
    #[test]
    fn test_publickey_pem_roundtrip() {
        let original = SecretKey::generate().public_key();
        let encoded = original.to_pem();
        let restored = PublicKey::from_pem(&encoded).unwrap();
        assert_eq!(original.to_bytes(), restored.to_bytes());
    }

    /// Sealing to a public key and opening with its secret key returns the
    /// original plaintext, with and without plaintext / associated data.
    #[test]
    fn test_seal_open() {
        let secret = SecretKey::generate();
        let public = secret.public_key();

        // (plaintext, associated data) combinations to exercise.
        let cases: [(&[u8], &[u8]); 3] = [
            (&[], b"message to authenticate"),
            (b"message to encrypt", &[]),
            (b"message to encrypt", b"message to authenticate"),
        ];

        for &(plaintext, aad) in cases.iter() {
            let (encap, sealed) = match public.seal(plaintext, aad, b"test") {
                Ok(output) => output,
                Err(e) => panic!("failed to seal message: {}", e),
            };

            let cleartext = match secret.open(&encap, &sealed, aad, b"test") {
                Ok(opened) => opened,
                Err(e) => panic!("failed to open message: {}", e),
            };

            assert_eq!(cleartext, plaintext, "unexpected cleartext");
        }
    }
}