// anubis_age/protocol.rs

//! Encryption and decryption routines for age.
2
3use anubis_core::format::is_arbitrary_string;
4use rand::{rngs::OsRng, RngCore};
5
6use std::collections::HashSet;
7use std::io::{self, BufRead, Read, Write};
8use std::iter;
9
10use crate::{
11    error::{DecryptError, EncryptError},
12    format::{Header, HeaderV1},
13    keys::{mac_key, new_file_key, v1_payload_key},
14    primitives::stream::{PayloadKey, Stream, StreamReader, StreamWriter},
15    Identity, Recipient,
16};
17
18#[cfg(feature = "async")]
19use futures::io::{AsyncBufRead, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
20
/// The 16-byte nonce stored in an age file directly after the header.
///
/// It is combined with the file key to derive the payload key.
pub(crate) struct Nonce([u8; 16]);
22
23impl AsRef<[u8]> for Nonce {
24    fn as_ref(&self) -> &[u8] {
25        &self.0
26    }
27}
28
29impl Nonce {
30    fn random() -> Self {
31        let mut nonce = [0; 16];
32        OsRng.fill_bytes(&mut nonce);
33        Nonce(nonce)
34    }
35
36    fn read<R: Read>(input: &mut R) -> io::Result<Self> {
37        let mut nonce = [0; 16];
38        input.read_exact(&mut nonce)?;
39        Ok(Nonce(nonce))
40    }
41
42    #[cfg(feature = "async")]
43    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
44    async fn read_async<R: AsyncRead + Unpin>(input: &mut R) -> io::Result<Self> {
45        let mut nonce = [0; 16];
46        input.read_exact(&mut nonce).await?;
47        Ok(Nonce(nonce))
48    }
49}
50
51/// Encryptor for creating an age file.
52pub struct Encryptor {
53    header: Header,
54    nonce: Nonce,
55    payload_key: PayloadKey,
56}
57
58impl Encryptor {
59    /// Constructs an `Encryptor` that will create an age file encrypted to a list of
60    /// recipients.
61    pub fn with_recipients<'a>(
62        recipients: impl Iterator<Item = &'a dyn Recipient>,
63    ) -> Result<Self, EncryptError> {
64        let file_key = new_file_key();
65
66        let recipients = {
67            let mut labels: Option<HashSet<String>> = None;
68
69            let mut stanzas = vec![];
70            let mut have_recipients = false;
71            for recipient in recipients {
72                have_recipients = true;
73                let (mut r_stanzas, r_labels) = recipient.wrap_file_key(&file_key)?;
74
75                if let Some(expected) = labels.as_ref() {
76                    if *expected != r_labels {
77                        return Err(EncryptError::IncompatibleRecipients {
78                            l_labels: expected.clone(),
79                            r_labels,
80                        });
81                    }
82                } else if r_labels.iter().all(is_arbitrary_string) {
83                    labels = Some(r_labels.clone());
84                } else {
85                    return Err(EncryptError::InvalidRecipientLabels(r_labels));
86                }
87
88                stanzas.append(&mut r_stanzas);
89            }
90            if !have_recipients {
91                return Err(EncryptError::MissingRecipients);
92            }
93            stanzas
94        };
95
96        let header = HeaderV1::new(recipients, mac_key(&file_key))?;
97        let nonce = Nonce::random();
98        let payload_key = v1_payload_key(&file_key, &header, &nonce).expect("MAC is correct");
99
100        Ok(Self {
101            header: Header::V1(header),
102            nonce,
103            payload_key,
104        })
105    }
106
107    /// Creates a wrapper around a writer that will encrypt its input.
108    ///
109    /// Returns errors from the underlying writer while writing the header.
110    ///
111    /// You **MUST** call [`StreamWriter::finish`] when you are done writing, in order to
112    /// finish the encryption process. Failing to call [`StreamWriter::finish`] will
113    /// result in a truncated file that will fail to decrypt.
114    pub fn wrap_output<W: Write>(self, mut output: W) -> io::Result<StreamWriter<W>> {
115        let Self {
116            header,
117            nonce,
118            payload_key,
119        } = self;
120        header.write(&mut output)?;
121        output.write_all(nonce.as_ref())?;
122        Ok(Stream::encrypt(payload_key, output))
123    }
124
125    /// Creates a wrapper around a writer that will encrypt its input.
126    ///
127    /// Returns errors from the underlying writer while writing the header.
128    ///
129    /// You **MUST** call [`AsyncWrite::poll_close`] when you are done writing, in order
130    /// to finish the encryption process. Failing to call [`AsyncWrite::poll_close`]
131    /// will result in a truncated file that will fail to decrypt.
132    #[cfg(feature = "async")]
133    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
134    pub async fn wrap_async_output<W: AsyncWrite + Unpin>(
135        self,
136        mut output: W,
137    ) -> io::Result<StreamWriter<W>> {
138        let Self {
139            header,
140            nonce,
141            payload_key,
142        } = self;
143        header.write_async(&mut output).await?;
144        output.write_all(nonce.as_ref()).await?;
145        Ok(Stream::encrypt_async(payload_key, output))
146    }
147}
148
149/// Decryptor for an age file.
150pub struct Decryptor<R> {
151    /// The age file.
152    input: R,
153    /// The age file's header.
154    header: Header,
155    /// The age file's AEAD nonce
156    nonce: Nonce,
157}
158
159impl<R> Decryptor<R> {
160    fn from_v1_header(input: R, header: HeaderV1, nonce: Nonce) -> Result<Self, DecryptError> {
161        // Enforce structural requirements on the v1 header.
162        if header.is_valid() {
163            Ok(Self {
164                input,
165                header: Header::V1(header),
166                nonce,
167            })
168        } else {
169            Err(DecryptError::InvalidHeader)
170        }
171    }
172
173    fn obtain_payload_key<'a>(
174        &self,
175        mut identities: impl Iterator<Item = &'a dyn Identity>,
176    ) -> Result<PayloadKey, DecryptError> {
177        match &self.header {
178            Header::V1(header) => identities
179                .find_map(|key| key.unwrap_stanzas(&header.recipients))
180                .unwrap_or(Err(DecryptError::NoMatchingKeys))
181                .and_then(|file_key| v1_payload_key(&file_key, header, &self.nonce)),
182            Header::Unknown(_) => unreachable!(),
183        }
184    }
185}
186
187impl<R: Read> Decryptor<R> {
188    /// Attempts to create a decryptor for an age file.
189    ///
190    /// Returns an error if the input does not contain a valid age file.
191    ///
192    /// # Performance
193    ///
194    /// This constructor will work with any type implementing [`io::Read`], and uses a
195    /// slower parser and internal buffering to ensure no overreading occurs. Consider
196    /// using [`Decryptor::new_buffered`] for types implementing `std::io::BufRead`, which
197    /// includes `&[u8]` slices.
198    pub fn new(mut input: R) -> Result<Self, DecryptError> {
199        let header = Header::read(&mut input)?;
200
201        match header {
202            Header::V1(v1_header) => {
203                let nonce = Nonce::read(&mut input)?;
204                Decryptor::from_v1_header(input, v1_header, nonce)
205            }
206            Header::Unknown(_) => Err(DecryptError::UnknownFormat),
207        }
208    }
209
210    /// Attempts to decrypt the age file.
211    ///
212    /// If successful, returns a reader that will provide the plaintext.
213    pub fn decrypt<'a>(
214        self,
215        identities: impl Iterator<Item = &'a dyn Identity>,
216    ) -> Result<StreamReader<R>, DecryptError> {
217        self.obtain_payload_key(identities)
218            .map(|payload_key| Stream::decrypt(payload_key, self.input))
219    }
220}
221
222impl<R: BufRead> Decryptor<R> {
223    /// Attempts to create a decryptor for an age file.
224    ///
225    /// Returns an error if the input does not contain a valid age file.
226    ///
227    /// # Performance
228    ///
229    /// This constructor is more performant than [`Decryptor::new`] for types implementing
230    /// [`io::BufRead`], which includes `&[u8]` slices.
231    pub fn new_buffered(mut input: R) -> Result<Self, DecryptError> {
232        let header = Header::read_buffered(&mut input)?;
233
234        match header {
235            Header::V1(v1_header) => {
236                let nonce = Nonce::read(&mut input)?;
237                Decryptor::from_v1_header(input, v1_header, nonce)
238            }
239            Header::Unknown(_) => Err(DecryptError::UnknownFormat),
240        }
241    }
242}
243
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncRead + Unpin> Decryptor<R> {
    /// Attempts to create a decryptor for an age file read asynchronously.
    ///
    /// Returns an error if the input does not contain a valid age file:
    /// [`DecryptError::UnknownFormat`] for an unrecognised header version,
    /// [`DecryptError::InvalidHeader`] for a v1 header that fails validation, or
    /// whatever error the header/nonce reads produce.
    ///
    /// # Performance
    ///
    /// This constructor will work with any type implementing [`AsyncRead`], and uses a
    /// slower parser and internal buffering to ensure no overreading occurs. Consider
    /// using [`Decryptor::new_async_buffered`] for types implementing [`AsyncBufRead`],
    /// which includes `&[u8]` slices.
    pub async fn new_async(mut input: R) -> Result<Self, DecryptError> {
        let v1_header = match Header::read_async(&mut input).await? {
            Header::V1(v1_header) => v1_header,
            Header::Unknown(_) => return Err(DecryptError::UnknownFormat),
        };

        let nonce = Nonce::read_async(&mut input).await?;
        Decryptor::from_v1_header(input, v1_header, nonce)
    }

    /// Attempts to decrypt the age file.
    ///
    /// Tries `identities` in order until one unwraps a recipient stanza. If
    /// successful, returns an async reader that will provide the plaintext.
    pub fn decrypt_async<'a>(
        self,
        identities: impl Iterator<Item = &'a dyn Identity>,
    ) -> Result<StreamReader<R>, DecryptError> {
        let payload_key = self.obtain_payload_key(identities)?;
        Ok(Stream::decrypt_async(payload_key, self.input))
    }
}
280
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncBufRead + Unpin> Decryptor<R> {
    /// Attempts to create a decryptor for an age file read from a buffered async
    /// source.
    ///
    /// Returns an error if the input does not contain a valid age file:
    /// [`DecryptError::UnknownFormat`] for an unrecognised header version,
    /// [`DecryptError::InvalidHeader`] for a v1 header that fails validation, or
    /// whatever error the header/nonce reads produce.
    ///
    /// # Performance
    ///
    /// This constructor is more performant than [`Decryptor::new_async`] for types
    /// implementing [`AsyncBufRead`], which includes `&[u8]` slices.
    pub async fn new_async_buffered(mut input: R) -> Result<Self, DecryptError> {
        let v1_header = match Header::read_async_buffered(&mut input).await? {
            Header::V1(v1_header) => v1_header,
            Header::Unknown(_) => return Err(DecryptError::UnknownFormat),
        };

        let nonce = Nonce::read_async(&mut input).await?;
        Decryptor::from_v1_header(input, v1_header, nonce)
    }
}
304
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::io::{Read, Write};
    use std::iter;

    use super::{Decryptor, Encryptor};
    use crate::{pqc::mlkem, EncryptError, Identity, Recipient};

    #[cfg(feature = "async")]
    use futures::{
        io::{AsyncRead, AsyncWrite},
        pin_mut,
        task::Poll,
        Future,
    };
    #[cfg(feature = "async")]
    use futures_test::task::noop_context;

    /// Encrypts a fixed message to `recipients` with the synchronous API, decrypts it
    /// with `identities`, and asserts the plaintext is recovered unchanged.
    fn recipient_round_trip<'a>(
        recipients: impl Iterator<Item = &'a dyn Recipient>,
        identities: impl Iterator<Item = &'a dyn Identity>,
    ) {
        let test_msg = b"This is a test message. For testing.";

        let mut encrypted = vec![];
        let e = Encryptor::with_recipients(recipients).unwrap();
        {
            let mut w = e.wrap_output(&mut encrypted).unwrap();
            w.write_all(test_msg).unwrap();
            w.finish().unwrap();
        }

        let d = Decryptor::new_buffered(&encrypted[..]).unwrap();
        let mut r = d.decrypt(identities).unwrap();
        let mut decrypted = vec![];
        r.read_to_end(&mut decrypted).unwrap();

        assert_eq!(&decrypted[..], &test_msg[..]);
    }

    /// Same as `recipient_round_trip`, but drives the async API by polling by hand
    /// with a no-op waker. All I/O targets an in-memory buffer, so every poll is
    /// expected to complete immediately; `Poll::Pending` is treated as a failure.
    #[cfg(feature = "async")]
    fn recipient_async_round_trip<'a>(
        recipients: impl Iterator<Item = &'a dyn Recipient>,
        identities: impl Iterator<Item = &'a dyn Identity>,
    ) {
        let test_msg = b"This is a test message. For testing.";
        let mut cx = noop_context();

        let mut encrypted = vec![];
        let e = Encryptor::with_recipients(recipients).unwrap();
        {
            let w = {
                let f = e.wrap_async_output(&mut encrypted);
                pin_mut!(f);

                loop {
                    match f.as_mut().poll(&mut cx) {
                        Poll::Ready(Ok(w)) => break w,
                        Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                        Poll::Pending => panic!("Unexpected Pending"),
                    }
                }
            };
            pin_mut!(w);

            // Write until the remaining slice is empty (a zero-length write returns 0).
            let mut tmp = &test_msg[..];
            loop {
                match w.as_mut().poll_write(&mut cx, tmp) {
                    Poll::Ready(Ok(0)) => break,
                    Poll::Ready(Ok(written)) => tmp = &tmp[written..],
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
            // Closing the writer is what finishes the encryption.
            loop {
                match w.as_mut().poll_close(&mut cx) {
                    Poll::Ready(Ok(())) => break,
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        }

        let d = {
            let f = Decryptor::new_async(&encrypted[..]);
            pin_mut!(f);

            loop {
                match f.as_mut().poll(&mut cx) {
                    Poll::Ready(Ok(w)) => break w,
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        let decrypted = {
            let mut buf = vec![];
            let r = d.decrypt_async(identities).unwrap();
            pin_mut!(r);

            let mut tmp = [0; 4096];
            loop {
                match r.as_mut().poll_read(&mut cx, &mut tmp) {
                    Poll::Ready(Ok(0)) => break buf,
                    Poll::Ready(Ok(read)) => buf.extend_from_slice(&tmp[..read]),
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        assert_eq!(&decrypted[..], &test_msg[..]);
    }

    #[test]
    fn mlkem_round_trip() {
        let identity = mlkem::Identity::generate();
        let recipient = identity.to_public();
        recipient_round_trip(iter::once(&recipient as _), iter::once(&identity as _));
    }

    #[cfg(feature = "async")]
    #[test]
    fn mlkem_async_round_trip() {
        let identity = mlkem::Identity::generate();
        let recipient = identity.to_public();
        recipient_async_round_trip(iter::once(&recipient as _), iter::once(&identity as _));
    }

    /// Test-only recipient that wraps like an ML-KEM recipient but adds an extra label,
    /// so its label set differs from a plain ML-KEM recipient's.
    struct IncompatibleRecipient(mlkem::Recipient);

    impl Recipient for IncompatibleRecipient {
        fn wrap_file_key(
            &self,
            file_key: &anubis_core::format::FileKey,
        ) -> Result<(Vec<anubis_core::format::Stanza>, HashSet<String>), EncryptError> {
            self.0.wrap_file_key(file_key).map(|(stanzas, mut labels)| {
                labels.insert("incompatible".into());
                (stanzas, labels)
            })
        }
    }

    #[test]
    fn incompatible_recipients() {
        let recipient = mlkem::Identity::generate().to_public();
        let incompatible = IncompatibleRecipient(recipient.clone());

        let recipients = [&recipient as &dyn Recipient, &incompatible as _];

        assert!(matches!(
            Encryptor::with_recipients(recipients.into_iter()),
            Err(EncryptError::IncompatibleRecipients { .. }),
        ));
    }

    #[test]
    fn missing_recipients() {
        assert!(matches!(
            Encryptor::with_recipients(iter::empty::<&dyn Recipient>()),
            Err(EncryptError::MissingRecipients),
        ));
    }
}