1use std::collections::HashSet;
7use std::fmt;
8use std::io;
9
10use anubis_core::{
11 format::{FileKey, Stanza, FILE_KEY_BYTES},
12 primitives::{aead_decrypt, aead_encrypt, hkdf},
13 secrecy::{ExposeSecret, SecretString},
14};
15use bech32::{ToBase32, Variant};
16use oqs::kem::{Algorithm, Kem, SharedSecret};
17use zeroize::{Zeroize, Zeroizing};
18
19use crate::{
20 error::{DecryptError, EncryptError},
21 util::parse_bech32,
22};
23
/// Bech32 human-readable part used when encoding and parsing an [`Identity`].
const SECRET_KEY_PREFIX: &str = "ANUBIS-MLKEM-1024-SECRET";
/// Bech32 human-readable part used when encoding and parsing a [`Recipient`].
const PUBLIC_KEY_PREFIX: &str = "anubis1mlkem1";

/// Stanza tag written by `wrap_file_key` and matched by `unwrap_stanza`.
pub const MLKEM1024_RECIPIENT_TAG: &str = "MLKEM-1024";
/// Label passed to `hkdf` when deriving the key that wraps the file key.
const MLKEM1024_RECIPIENT_KEY_LABEL: &[u8] = b"anubis-encryption.org/v1/MLKEM-1024";

/// Length in bytes of an ML-KEM-1024 public key.
pub const MLKEM1024_PUBLIC_KEY_BYTES: usize = 1568;
/// Length in bytes of an ML-KEM-1024 secret key.
pub const MLKEM1024_SECRET_KEY_BYTES: usize = 3168;
/// Length in bytes of an ML-KEM-1024 ciphertext.
pub const MLKEM1024_CIPHERTEXT_BYTES: usize = 1568;
/// Length in bytes of the wrapped file key stored after the KEM ciphertext in
/// a stanza body.
// NOTE(review): the extra 16 bytes are presumably the AEAD authentication tag
// added by `aead_encrypt` — confirm against `anubis_core::primitives`.
const ENCRYPTED_FILE_KEY_BYTES: usize = FILE_KEY_BYTES + 16;
38
39fn mlkem() -> Kem {
40 oqs::init();
41 Kem::new(Algorithm::MlKem1024).expect("ML-KEM-1024 algorithm available")
42}
43
44fn derive_wrap_key(
45 shared_secret: &SharedSecret,
46 public_key: &[u8],
47 ciphertext: &[u8],
48) -> Zeroizing<[u8; 32]> {
49 let mut salt = Vec::with_capacity(public_key.len() + ciphertext.len());
50 salt.extend_from_slice(public_key);
51 salt.extend_from_slice(ciphertext);
52 let key = hkdf(&salt, MLKEM1024_RECIPIENT_KEY_LABEL, shared_secret.as_ref());
53 salt.zeroize();
54 Zeroizing::new(key)
55}
56
/// Wraps `msg` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
fn invalid_data(msg: &str) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, msg)
}
60
/// An ML-KEM-1024 identity: a secret key stored together with its public key.
#[derive(Clone)]
pub struct Identity {
    /// Raw secret-key bytes; the `Zeroizing` wrapper wipes them on drop.
    secret_key: Zeroizing<Vec<u8>>,
    /// Raw public-key bytes, handed out via `to_public` and mixed into the
    /// wrap-key derivation when unwrapping a stanza.
    public_key: Vec<u8>,
}
69
70impl Identity {
71 pub fn generate() -> Self {
73 let kem = mlkem();
74 let (pk, sk) = kem.keypair().expect("ML-KEM keypair");
75 Self {
76 secret_key: Zeroizing::new(sk.as_ref().to_vec()),
77 public_key: pk.as_ref().to_vec(),
78 }
79 }
80
81 pub fn to_string(&self) -> SecretString {
85 let mut material =
86 Vec::with_capacity(MLKEM1024_SECRET_KEY_BYTES + MLKEM1024_PUBLIC_KEY_BYTES);
87 material.extend_from_slice(self.secret_key.as_ref());
88 material.extend_from_slice(&self.public_key);
89
90 let encoded = bech32::encode(SECRET_KEY_PREFIX, material.to_base32(), Variant::Bech32)
91 .expect("valid HRP");
92
93 material.zeroize();
94
95 SecretString::from(encoded.to_uppercase())
96 }
97
98 pub fn to_public(&self) -> Recipient {
100 Recipient {
101 public_key: self.public_key.clone(),
102 }
103 }
104
105 pub(crate) fn decapsulate(&self, ct: &[u8; 1568]) -> Result<[u8; 32], DecryptError> {
107 let kem = mlkem();
108 let sk = kem
109 .secret_key_from_bytes(self.secret_key.as_ref())
110 .ok_or(DecryptError::InvalidHeader)?;
111 let ciphertext = kem
112 .ciphertext_from_bytes(ct)
113 .ok_or(DecryptError::InvalidHeader)?;
114 let shared_secret = kem
115 .decapsulate(sk, ciphertext)
116 .map_err(|_| DecryptError::DecryptionFailed)?;
117
118 let mut ss_bytes = [0u8; 32];
119 ss_bytes.copy_from_slice(shared_secret.as_ref());
120 Ok(ss_bytes)
121 }
122}
123
124impl fmt::Display for Identity {
125 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126 write!(f, "{}", self.to_string().expose_secret())
127 }
128}
129
130impl std::str::FromStr for Identity {
131 type Err = &'static str;
132
133 fn from_str(s: &str) -> Result<Self, Self::Err> {
134 let (hrp, bytes) = parse_bech32(s).ok_or("invalid Bech32 encoding")?;
135 if !hrp.eq_ignore_ascii_case(SECRET_KEY_PREFIX) {
136 return Err("incorrect HRP");
137 }
138 if bytes.len() != MLKEM1024_SECRET_KEY_BYTES + MLKEM1024_PUBLIC_KEY_BYTES {
139 return Err("incorrect identity length");
140 }
141
142 let secret_key = Zeroizing::new(bytes[..MLKEM1024_SECRET_KEY_BYTES].to_vec());
143 let public_key = bytes[MLKEM1024_SECRET_KEY_BYTES..].to_vec();
144
145 Ok(Self {
146 secret_key,
147 public_key,
148 })
149 }
150}
151
152impl crate::Identity for Identity {
153 fn unwrap_stanza(&self, stanza: &Stanza) -> Option<Result<FileKey, DecryptError>> {
154 if stanza.tag != MLKEM1024_RECIPIENT_TAG {
155 return None;
156 }
157
158 if stanza.body.len() != MLKEM1024_CIPHERTEXT_BYTES + ENCRYPTED_FILE_KEY_BYTES {
159 return Some(Err(DecryptError::InvalidHeader));
160 }
161
162 let (ct_bytes, encrypted_file_key) = stanza.body.split_at(MLKEM1024_CIPHERTEXT_BYTES);
163
164 let kem = mlkem();
165 let sk = match kem.secret_key_from_bytes(self.secret_key.as_ref()) {
166 Some(sk) => sk,
167 None => return Some(Err(DecryptError::InvalidHeader)),
168 };
169 let ciphertext = match kem.ciphertext_from_bytes(ct_bytes) {
170 Some(ct) => ct,
171 None => return Some(Err(DecryptError::InvalidHeader)),
172 };
173 let shared_secret = match kem.decapsulate(sk, ciphertext) {
174 Ok(ss) => ss,
175 Err(_) => return Some(Err(DecryptError::InvalidHeader)),
176 };
177
178 let wrap_key = derive_wrap_key(&shared_secret, &self.public_key, ct_bytes);
179
180 aead_decrypt(&wrap_key, FILE_KEY_BYTES, encrypted_file_key)
181 .ok()
182 .map(|mut plaintext| {
183 Ok(FileKey::init_with_mut(|file_key| {
184 file_key.copy_from_slice(&plaintext);
185 plaintext.zeroize();
186 }))
187 })
188 }
189}
190
/// An ML-KEM-1024 recipient, identified by its raw public key.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Recipient {
    /// Raw public-key bytes; `FromStr` only accepts
    /// `MLKEM1024_PUBLIC_KEY_BYTES` of them.
    public_key: Vec<u8>,
}
198
199impl Recipient {
200 fn ensure_length(bytes: &[u8]) -> Result<(), &'static str> {
201 if bytes.len() == MLKEM1024_PUBLIC_KEY_BYTES {
202 Ok(())
203 } else {
204 Err("incorrect pubkey length")
205 }
206 }
207
208 pub(crate) fn encapsulate<R: rand::Rng + rand::CryptoRng>(
210 &self,
211 _rng: &mut R,
212 ) -> Result<([u8; 1568], Zeroizing<Vec<u8>>), EncryptError> {
213 let kem = mlkem();
214 let pk = kem
215 .public_key_from_bytes(&self.public_key)
216 .ok_or_else(|| EncryptError::Io(invalid_data("invalid ML-KEM public key")))?;
217 let (ciphertext, shared_secret) = kem.encapsulate(pk).map_err(|_| {
218 EncryptError::Io(invalid_data("failed to encapsulate ML-KEM shared secret"))
219 })?;
220
221 let mut ct_bytes = [0u8; 1568];
222 ct_bytes.copy_from_slice(ciphertext.as_ref());
223 Ok((ct_bytes, Zeroizing::new(shared_secret.as_ref().to_vec())))
224 }
225}
226
227impl fmt::Display for Recipient {
228 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229 write!(
230 f,
231 "{}",
232 bech32::encode(
233 PUBLIC_KEY_PREFIX,
234 self.public_key.to_base32(),
235 Variant::Bech32
236 )
237 .expect("valid HRP")
238 )
239 }
240}
241
242impl fmt::Debug for Recipient {
243 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
244 write!(f, "{}", self)
245 }
246}
247
248impl std::str::FromStr for Recipient {
249 type Err = &'static str;
250
251 fn from_str(s: &str) -> Result<Self, Self::Err> {
252 let (hrp, bytes) = parse_bech32(s).ok_or("invalid Bech32 encoding")?;
253 if !hrp.eq_ignore_ascii_case(PUBLIC_KEY_PREFIX) {
254 return Err("incorrect HRP");
255 }
256 Self::ensure_length(&bytes)?;
257 Ok(Self { public_key: bytes })
258 }
259}
260
261impl crate::Recipient for Recipient {
262 fn wrap_file_key(
263 &self,
264 file_key: &FileKey,
265 ) -> Result<(Vec<Stanza>, HashSet<String>), EncryptError> {
266 let kem = mlkem();
267
268 let pk = kem
269 .public_key_from_bytes(&self.public_key)
270 .ok_or_else(|| EncryptError::Io(invalid_data("invalid ML-KEM public key")))?;
271 let (ciphertext, shared_secret) = kem.encapsulate(pk).map_err(|_| {
272 EncryptError::Io(invalid_data("failed to encapsulate ML-KEM shared secret"))
273 })?;
274
275 let wrap_key = derive_wrap_key(&shared_secret, &self.public_key, ciphertext.as_ref());
276 let encrypted_file_key = aead_encrypt(&wrap_key, file_key.expose_secret());
277
278 let mut body = Vec::with_capacity(MLKEM1024_CIPHERTEXT_BYTES + ENCRYPTED_FILE_KEY_BYTES);
279 body.extend_from_slice(ciphertext.as_ref());
280 body.extend_from_slice(&encrypted_file_key);
281
282 let mut labels = HashSet::new();
283 labels.insert("postquantum".to_string());
284 labels.insert("nist-level-5".to_string());
285
286 Ok((
287 vec![Stanza {
288 tag: MLKEM1024_RECIPIENT_TAG.to_owned(),
289 args: vec![],
290 body,
291 }],
292 labels,
293 ))
294 }
295}
296
#[cfg(test)]
mod tests {
    use anubis_core::{format::FileKey, secrecy::ExposeSecret};

    use super::{Identity, Recipient};
    use crate::{Identity as _, Recipient as _};

    /// A file key wrapped for a recipient is recovered by the matching identity.
    #[test]
    fn round_trip() {
        let identity = Identity::generate();
        let file_key = FileKey::new(Box::new([42; 16]));

        let wrapped = identity.to_public().wrap_file_key(&file_key);
        let (stanzas, labels) = wrapped.unwrap();
        assert!(labels.contains("postquantum"));
        assert!(labels.contains("nist-level-5"));
        assert_eq!(stanzas.len(), 1);

        let unwrapped = identity.unwrap_stanzas(&stanzas).unwrap().unwrap();
        assert_eq!(unwrapped.expose_secret(), file_key.expose_secret());
    }

    /// Identities and recipients survive a Bech32 encode/parse cycle.
    #[test]
    fn bech32_round_trip() {
        let identity = Identity::generate();

        // Identity: encode, parse back, compare the public halves.
        let identity_text = identity.to_string();
        let parsed_identity: Identity = identity_text.expose_secret().parse().unwrap();
        assert_eq!(identity.public_key, parsed_identity.public_key);

        // Recipient: encode via `Display`, parse back, compare whole values.
        let recipient = identity.to_public();
        let recipient_text = recipient.to_string();
        let parsed_recipient: Recipient = recipient_text.parse().unwrap();
        assert_eq!(recipient, parsed_recipient);
    }
}