1use anubis_core::format::is_arbitrary_string;
4use rand::{rngs::OsRng, RngCore};
5
6use std::collections::HashSet;
7use std::io::{self, BufRead, Read, Write};
8use std::iter;
9
10use crate::{
11 error::{DecryptError, EncryptError},
12 format::{Header, HeaderV1},
13 keys::{mac_key, new_file_key, v1_payload_key},
14 primitives::stream::{PayloadKey, Stream, StreamReader, StreamWriter},
15 Identity, Recipient,
16};
17
18#[cfg(feature = "async")]
19use futures::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
20
/// A 16-byte nonce written immediately after the header of an encrypted file.
pub(crate) struct Nonce([u8; 16]);

impl AsRef<[u8]> for Nonce {
    /// Exposes all 16 raw nonce bytes as a slice.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
28
29impl Nonce {
30 fn random() -> Self {
31 let mut nonce = [0; 16];
32 OsRng.fill_bytes(&mut nonce);
33 Nonce(nonce)
34 }
35
36 fn read<R: Read>(input: &mut R) -> io::Result<Self> {
37 let mut nonce = [0; 16];
38 input.read_exact(&mut nonce)?;
39 Ok(Nonce(nonce))
40 }
41
42 #[cfg(feature = "async")]
43 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
44 async fn read_async<R: AsyncRead + Unpin>(input: &mut R) -> io::Result<Self> {
45 let mut nonce = [0; 16];
46 input.read_exact(&mut nonce).await?;
47 Ok(Nonce(nonce))
48 }
49}
50
/// Encrypts a payload to a fixed set of recipients.
///
/// Built by [`Encryptor::with_recipients`]; the header, nonce and payload key
/// are all fixed at construction time and consumed by the `wrap_*` methods.
pub struct Encryptor {
    /// v1 header holding every recipient's wrapped-file-key stanzas, built with
    /// a MAC key derived from the file key.
    header: Header,
    /// Random nonce written right after the header; also an input to
    /// payload-key derivation.
    nonce: Nonce,
    /// Key used to encrypt the payload stream.
    payload_key: PayloadKey,
}
57
58impl Encryptor {
59 pub fn with_recipients<'a>(
62 recipients: impl Iterator<Item = &'a dyn Recipient>,
63 ) -> Result<Self, EncryptError> {
64 let file_key = new_file_key();
65
66 let recipients = {
67 let mut labels: Option<HashSet<String>> = None;
68
69 let mut stanzas = vec![];
70 let mut have_recipients = false;
71 for recipient in recipients {
72 have_recipients = true;
73 let (mut r_stanzas, r_labels) = recipient.wrap_file_key(&file_key)?;
74
75 if let Some(expected) = labels.as_ref() {
76 if *expected != r_labels {
77 return Err(EncryptError::IncompatibleRecipients {
78 l_labels: expected.clone(),
79 r_labels,
80 });
81 }
82 } else if r_labels.iter().all(is_arbitrary_string) {
83 labels = Some(r_labels.clone());
84 } else {
85 return Err(EncryptError::InvalidRecipientLabels(r_labels));
86 }
87
88 stanzas.append(&mut r_stanzas);
89 }
90 if !have_recipients {
91 return Err(EncryptError::MissingRecipients);
92 }
93 stanzas
94 };
95
96 let header = HeaderV1::new(recipients, mac_key(&file_key))?;
97 let nonce = Nonce::random();
98 let payload_key = v1_payload_key(&file_key, &header, &nonce).expect("MAC is correct");
99
100 Ok(Self {
101 header: Header::V1(header),
102 nonce,
103 payload_key,
104 })
105 }
106
107 pub fn wrap_output<W: Write>(self, mut output: W) -> io::Result<StreamWriter<W>> {
115 let Self {
116 header,
117 nonce,
118 payload_key,
119 } = self;
120 header.write(&mut output)?;
121 output.write_all(nonce.as_ref())?;
122 Ok(Stream::encrypt(payload_key, output))
123 }
124
125 #[cfg(feature = "async")]
133 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
134 pub async fn wrap_async_output<W: AsyncWrite + Unpin>(
135 self,
136 mut output: W,
137 ) -> io::Result<StreamWriter<W>> {
138 let Self {
139 header,
140 nonce,
141 payload_key,
142 } = self;
143 header.write_async(&mut output).await?;
144 output.write_all(nonce.as_ref()).await?;
145 Ok(Stream::encrypt_async(payload_key, output))
146 }
147}
148
/// Decryptor for an encrypted file whose header has been parsed and validated.
///
/// Constructed by the `new*` functions, which only succeed for v1 headers that
/// pass `HeaderV1::is_valid`.
pub struct Decryptor<R> {
    /// The input, positioned just after the header and nonce (the start of the
    /// encrypted payload).
    input: R,
    /// The parsed header; the constructors only ever store `Header::V1`.
    header: Header,
    /// Nonce read from the input immediately after the header.
    nonce: Nonce,
}
158
159impl<R> Decryptor<R> {
160 fn from_v1_header(input: R, header: HeaderV1, nonce: Nonce) -> Result<Self, DecryptError> {
161 if header.is_valid() {
163 Ok(Self {
164 input,
165 header: Header::V1(header),
166 nonce,
167 })
168 } else {
169 Err(DecryptError::InvalidHeader)
170 }
171 }
172
173 fn obtain_payload_key<'a>(
174 &self,
175 mut identities: impl Iterator<Item = &'a dyn Identity>,
176 ) -> Result<PayloadKey, DecryptError> {
177 match &self.header {
178 Header::V1(header) => identities
179 .find_map(|key| key.unwrap_stanzas(&header.recipients))
180 .unwrap_or(Err(DecryptError::NoMatchingKeys))
181 .and_then(|file_key| v1_payload_key(&file_key, header, &self.nonce)),
182 Header::Unknown(_) => unreachable!(),
183 }
184 }
185}
186
187impl<R: Read> Decryptor<R> {
188 pub fn new(mut input: R) -> Result<Self, DecryptError> {
199 let header = Header::read(&mut input)?;
200
201 match header {
202 Header::V1(v1_header) => {
203 let nonce = Nonce::read(&mut input)?;
204 Decryptor::from_v1_header(input, v1_header, nonce)
205 }
206 Header::Unknown(_) => Err(DecryptError::UnknownFormat),
207 }
208 }
209
210 pub fn decrypt<'a>(
214 self,
215 identities: impl Iterator<Item = &'a dyn Identity>,
216 ) -> Result<StreamReader<R>, DecryptError> {
217 self.obtain_payload_key(identities)
218 .map(|payload_key| Stream::decrypt(payload_key, self.input))
219 }
220}
221
222impl<R: BufRead> Decryptor<R> {
223 pub fn new_buffered(mut input: R) -> Result<Self, DecryptError> {
232 let header = Header::read_buffered(&mut input)?;
233
234 match header {
235 Header::V1(v1_header) => {
236 let nonce = Nonce::read(&mut input)?;
237 Decryptor::from_v1_header(input, v1_header, nonce)
238 }
239 Header::Unknown(_) => Err(DecryptError::UnknownFormat),
240 }
241 }
242}
243
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncRead + Unpin> Decryptor<R> {
    /// Async counterpart of `Decryptor::new`: parses the header and nonce from
    /// `input` and returns a `Decryptor` positioned at the encrypted payload.
    ///
    /// # Errors
    ///
    /// - [`DecryptError::UnknownFormat`] if the header is not a recognised version.
    /// - [`DecryptError::InvalidHeader`] if the v1 header fails validation.
    /// - Any error from reading the header or the nonce.
    pub async fn new_async(mut input: R) -> Result<Self, DecryptError> {
        let header = Header::read_async(&mut input).await?;
        let v1_header = match header {
            Header::V1(parsed) => parsed,
            Header::Unknown(_) => return Err(DecryptError::UnknownFormat),
        };

        // The nonce immediately follows a v1 header.
        let nonce = Nonce::read_async(&mut input).await?;
        Self::from_v1_header(input, v1_header, nonce)
    }

    /// Unwraps the file key with `identities` and returns an async reader
    /// yielding the decrypted payload.
    ///
    /// # Errors
    ///
    /// [`DecryptError::NoMatchingKeys`] if no identity recognises a stanza, or
    /// the error reported by the identity or by payload-key derivation.
    pub fn decrypt_async<'a>(
        self,
        identities: impl Iterator<Item = &'a dyn Identity>,
    ) -> Result<StreamReader<R>, DecryptError> {
        let payload_key = self.obtain_payload_key(identities)?;
        Ok(Stream::decrypt_async(payload_key, self.input))
    }
}
280
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncBufRead + Unpin> Decryptor<R> {
    /// Like `Decryptor::new_async`, but parses the header with
    /// `Header::read_async_buffered` for inputs that implement `AsyncBufRead`.
    ///
    /// # Errors
    ///
    /// - [`DecryptError::UnknownFormat`] if the header is not a recognised version.
    /// - [`DecryptError::InvalidHeader`] if the v1 header fails validation.
    /// - Any error from reading the header or the nonce.
    pub async fn new_async_buffered(mut input: R) -> Result<Self, DecryptError> {
        let header = Header::read_async_buffered(&mut input).await?;
        let v1_header = match header {
            Header::V1(parsed) => parsed,
            Header::Unknown(_) => return Err(DecryptError::UnknownFormat),
        };

        // The nonce immediately follows a v1 header.
        let nonce = Nonce::read_async(&mut input).await?;
        Self::from_v1_header(input, v1_header, nonce)
    }
}
304
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    // NOTE(review): `BufReader` here and `SecretString` below are not referenced
    // anywhere in this module as shown; they look like leftover imports.
    use std::io::{BufReader, Read, Write};
    use std::iter;

    use anubis_core::secrecy::SecretString;

    use super::{Decryptor, Encryptor};
    use crate::{pqc::mlkem, EncryptError, Identity, Recipient};

    // Manual-polling helpers used by the async round-trip test.
    #[cfg(feature = "async")]
    use futures::{
        io::{AsyncRead, AsyncWrite},
        pin_mut,
        task::Poll,
        Future,
    };
    #[cfg(feature = "async")]
    use futures_test::task::noop_context;

    /// Encrypts a fixed message to `recipients`, decrypts it with `identities`
    /// (via `Decryptor::new_buffered`), and asserts the plaintext round-trips.
    fn recipient_round_trip<'a>(
        recipients: impl Iterator<Item = &'a dyn Recipient>,
        identities: impl Iterator<Item = &'a dyn Identity>,
    ) {
        let test_msg = b"This is a test message. For testing.";

        let mut encrypted = vec![];
        let e = Encryptor::with_recipients(recipients).unwrap();
        // Inner scope: `w` mutably borrows `encrypted`, so it must be finished
        // and dropped before the ciphertext is read back.
        {
            let mut w = e.wrap_output(&mut encrypted).unwrap();
            w.write_all(test_msg).unwrap();
            w.finish().unwrap();
        }

        let d = Decryptor::new_buffered(&encrypted[..]).unwrap();
        let mut r = d.decrypt(identities).unwrap();
        let mut decrypted = vec![];
        r.read_to_end(&mut decrypted).unwrap();

        assert_eq!(&decrypted[..], &test_msg[..]);
    }

    /// Async counterpart of `recipient_round_trip`, driven by hand with a no-op
    /// waker. All I/O is against in-memory buffers, so every poll is expected to
    /// be `Ready`; `Poll::Pending` is treated as a test failure.
    #[cfg(feature = "async")]
    fn recipient_async_round_trip<'a>(
        recipients: impl Iterator<Item = &'a dyn Recipient>,
        identities: impl Iterator<Item = &'a dyn Identity>,
    ) {
        let test_msg = b"This is a test message. For testing.";
        let mut cx = noop_context();

        let mut encrypted = vec![];
        let e = Encryptor::with_recipients(recipients).unwrap();
        {
            // Poll the header/nonce-writing future to completion.
            let w = {
                let f = e.wrap_async_output(&mut encrypted);
                pin_mut!(f);

                loop {
                    match f.as_mut().poll(&mut cx) {
                        Poll::Ready(Ok(w)) => break w,
                        Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                        Poll::Pending => panic!("Unexpected Pending"),
                    }
                }
            };
            pin_mut!(w);

            // Keep writing until the writer accepts zero bytes (expected once
            // `tmp` has been fully consumed).
            let mut tmp = &test_msg[..];
            loop {
                match w.as_mut().poll_write(&mut cx, tmp) {
                    Poll::Ready(Ok(0)) => break,
                    Poll::Ready(Ok(written)) => tmp = &tmp[written..],
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
            // Close the writer to finish the encrypted stream.
            loop {
                match w.as_mut().poll_close(&mut cx) {
                    Poll::Ready(Ok(())) => break,
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        }

        // Poll the header/nonce-parsing future to completion.
        let d = {
            let f = Decryptor::new_async(&encrypted[..]);
            pin_mut!(f);

            loop {
                match f.as_mut().poll(&mut cx) {
                    Poll::Ready(Ok(w)) => break w,
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        // Read the decrypted payload until EOF (`Ok(0)`).
        let decrypted = {
            let mut buf = vec![];
            let r = d.decrypt_async(identities).unwrap();
            pin_mut!(r);

            let mut tmp = [0; 4096];
            loop {
                match r.as_mut().poll_read(&mut cx, &mut tmp) {
                    Poll::Ready(Ok(0)) => break buf,
                    Poll::Ready(Ok(read)) => buf.extend_from_slice(&tmp[..read]),
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        assert_eq!(&decrypted[..], &test_msg[..]);
    }

    /// Round-trips a message through a single freshly generated ML-KEM key pair.
    #[test]
    fn mlkem_round_trip() {
        let identity = mlkem::Identity::generate();
        let recipient = identity.to_public();
        recipient_round_trip(iter::once(&recipient as _), iter::once(&identity as _));
    }

    /// Async variant of `mlkem_round_trip`.
    #[cfg(feature = "async")]
    #[test]
    fn mlkem_async_round_trip() {
        let identity = mlkem::Identity::generate();
        let recipient = identity.to_public();
        recipient_async_round_trip(iter::once(&recipient as _), iter::once(&identity as _));
    }

    /// Test recipient that wraps like an ML-KEM recipient but adds an extra
    /// `"incompatible"` label, so its label set differs from a plain ML-KEM
    /// recipient's.
    struct IncompatibleRecipient(mlkem::Recipient);

    impl Recipient for IncompatibleRecipient {
        fn wrap_file_key(
            &self,
            file_key: &anubis_core::format::FileKey,
        ) -> Result<(Vec<anubis_core::format::Stanza>, HashSet<String>), EncryptError> {
            self.0.wrap_file_key(file_key).map(|(stanzas, mut labels)| {
                labels.insert("incompatible".into());
                (stanzas, labels)
            })
        }
    }

    /// Mixing recipients whose label sets differ must be rejected with
    /// `IncompatibleRecipients`.
    #[test]
    fn incompatible_recipients() {
        let recipient = mlkem::Identity::generate().to_public();
        let incompatible = IncompatibleRecipient(recipient.clone());

        let recipients = [&recipient as &dyn Recipient, &incompatible as _];

        assert!(matches!(
            Encryptor::with_recipients(recipients.into_iter()),
            Err(EncryptError::IncompatibleRecipients { .. }),
        ));
    }
}