anubis_age/primitives/stream.rs

//! I/O helper structs for age file encryption and decryption.

use anubis_core::secrecy::{ExposeSecret, SecretSlice};
use chacha20poly1305::{
    aead::{generic_array::GenericArray, Aead, KeyInit, KeySizeUser},
    ChaCha20Poly1305,
};
use pin_project::pin_project;
use std::cmp;
use std::io::{self, Read, Seek, SeekFrom, Write};
use zeroize::Zeroize;

#[cfg(feature = "async")]
use futures::{
    io::{AsyncRead, AsyncWrite, Error},
    ready,
    task::{Context, Poll},
};
#[cfg(feature = "async")]
use std::pin::Pin;

const CHUNK_SIZE: usize = 64 * 1024;
const TAG_SIZE: usize = 16;
const ENCRYPTED_CHUNK_SIZE: usize = CHUNK_SIZE + TAG_SIZE;

pub(crate) struct PayloadKey(
    pub(crate) GenericArray<u8, <ChaCha20Poly1305 as KeySizeUser>::KeySize>,
);

impl Drop for PayloadKey {
    fn drop(&mut self) {
        self.0.as_mut_slice().zeroize();
    }
}

/// The nonce used in age's STREAM encryption.
///
/// Structured as an 11-byte big-endian counter followed by 1 byte of last-block flag
/// (`0x00` / `0x01`). We store this in the lower 12 bytes of a `u128`.
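///
/// # Layout sketch
///
/// An illustrative sketch of how the counter and flag map onto the 12-byte AEAD nonce
/// (not run as a doctest, since this type is crate-internal):
///
/// ```ignore
/// let mut nonce = Nonce::default();
/// nonce.set_counter(1); // counter = 1, last-chunk flag cleared
/// // 11 big-endian counter bytes followed by the last-chunk flag byte.
/// assert_eq!(nonce.to_bytes(), [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0x00]);
/// nonce.set_last(true).unwrap(); // mark the final chunk
/// assert_eq!(nonce.to_bytes()[11], 0x01);
/// ```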
#[derive(Clone, Copy, Default)]
struct Nonce(u128);

impl Nonce {
    /// Unsets last-chunk flag.
    fn set_counter(&mut self, val: u64) {
        self.0 = u128::from(val) << 8;
    }

    fn increment_counter(&mut self) {
        // Increment the 11-byte counter
        self.0 += 1 << 8;
        if self.0 >> (8 * 12) != 0 {
            panic!("We overflowed the nonce!");
        }
    }

    fn is_last(&self) -> bool {
        self.0 & 1 != 0
    }

    fn set_last(&mut self, last: bool) -> Result<(), ()> {
        if !self.is_last() {
            self.0 |= u128::from(last);
            Ok(())
        } else {
            Err(())
        }
    }

    fn to_bytes(self) -> [u8; 12] {
        self.0.to_be_bytes()[4..]
            .try_into()
            .expect("slice is correct length")
    }
}

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
struct EncryptedChunk {
    bytes: Vec<u8>,
    offset: usize,
}

/// `STREAM[key](plaintext)`
///
/// The [STREAM] construction for online authenticated encryption, instantiated with
/// ChaCha20-Poly1305 in 64 KiB chunks, and a nonce structure of an 11-byte big-endian
/// counter followed by 1 byte of last-block flag (`0x00` / `0x01`).
///
/// [STREAM]: https://eprint.iacr.org/2015/189.pdf
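///
/// # Chunk overhead sketch
///
/// Each chunk carries a 16-byte Poly1305 tag, so the ciphertext is the plaintext plus
/// `TAG_SIZE` bytes per chunk. An illustrative calculation (not run as a doctest; it
/// only uses the constants defined at the top of this module):
///
/// ```ignore
/// // A 150 KiB plaintext spans three 64 KiB chunks (the last one partial),
/// // so the STREAM payload is the plaintext plus three tags.
/// let plaintext_len = 150 * 1024;
/// let num_chunks = (plaintext_len + CHUNK_SIZE - 1) / CHUNK_SIZE; // = 3
/// assert_eq!(plaintext_len + num_chunks * TAG_SIZE, 150 * 1024 + 3 * 16);
/// ```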
pub(crate) struct Stream {
    aead: ChaCha20Poly1305,
    nonce: Nonce,
}

impl Stream {
    fn new(key: PayloadKey) -> Self {
        Stream {
            aead: ChaCha20Poly1305::new(&key.0),
            nonce: Nonce::default(),
        }
    }

    /// Wraps `STREAM` encryption under the given `key` around a writer.
    ///
    /// `key` must **never** be repeated across multiple streams. In `age` this is
    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
    /// random nonce.
    ///
    /// [`HKDF`]: anubis_core::primitives::hkdf
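    ///
    /// # Example
    ///
    /// A minimal sketch of the write-side flow, mirroring this module's tests (not run
    /// as a doctest, since this API is crate-internal and the key is a placeholder):
    ///
    /// ```ignore
    /// use std::io::Write;
    ///
    /// let mut ciphertext = vec![];
    /// let mut writer = Stream::encrypt(PayloadKey([7; 32].into()), &mut ciphertext);
    /// writer.write_all(b"hello age")?;
    /// // The final (possibly short) chunk is only written by `finish`.
    /// writer.finish()?;
    /// ```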
    pub(crate) fn encrypt<W: Write>(key: PayloadKey, inner: W) -> StreamWriter<W> {
        StreamWriter {
            stream: Self::new(key),
            inner,
            chunk: Vec::with_capacity(CHUNK_SIZE),
            #[cfg(feature = "async")]
            encrypted_chunk: None,
        }
    }

    /// Wraps `STREAM` encryption under the given `key` around an asynchronous writer.
    ///
    /// `key` must **never** be repeated across multiple streams. In `age` this is
    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
    /// random nonce.
    ///
    /// [`HKDF`]: anubis_core::primitives::hkdf
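    ///
    /// # Example
    ///
    /// An illustrative sketch (not run as a doctest; it assumes an async executor is
    /// driving the future, as in this module's `stream_async_io_copy` tests):
    ///
    /// ```ignore
    /// use futures::io::AsyncWriteExt;
    ///
    /// let mut ciphertext = vec![];
    /// let mut writer = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut ciphertext);
    /// futures::io::copy(&b"hello age"[..], &mut writer).await?;
    /// // `poll_close` writes the final chunk, so the writer must be closed.
    /// writer.close().await?;
    /// ```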
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub(crate) fn encrypt_async<W: AsyncWrite>(key: PayloadKey, inner: W) -> StreamWriter<W> {
        StreamWriter {
            stream: Self::new(key),
            inner,
            chunk: Vec::with_capacity(CHUNK_SIZE),
            encrypted_chunk: None,
        }
    }

    /// Wraps `STREAM` decryption under the given `key` around a reader.
    ///
    /// `key` must **never** be repeated across multiple streams. In `age` this is
    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
    /// random nonce.
    ///
    /// [`HKDF`]: anubis_core::primitives::hkdf
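    ///
    /// # Example
    ///
    /// A minimal sketch of the read-side flow (not run as a doctest; `ciphertext` is
    /// assumed to have been produced by [`Stream::encrypt`] under the same key):
    ///
    /// ```ignore
    /// use std::io::Read;
    ///
    /// let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), &ciphertext[..]);
    /// let mut plaintext = vec![];
    /// reader.read_to_end(&mut plaintext)?;
    /// ```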
    pub(crate) fn decrypt<R: Read>(key: PayloadKey, inner: R) -> StreamReader<R> {
        StreamReader {
            stream: Self::new(key),
            inner,
            encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
            encrypted_pos: 0,
            start: StartPos::Implicit(0),
            plaintext_len: None,
            cur_plaintext_pos: 0,
            chunk: None,
        }
    }

    /// Wraps `STREAM` decryption under the given `key` around an asynchronous reader.
    ///
    /// `key` must **never** be repeated across multiple streams. In `age` this is
    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
    /// random nonce.
    ///
    /// [`HKDF`]: anubis_core::primitives::hkdf
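    ///
    /// # Example
    ///
    /// An illustrative sketch (not run as a doctest; `ciphertext` is assumed to come
    /// from [`Stream::encrypt_async`] under the same key, inside an async context):
    ///
    /// ```ignore
    /// let reader = Stream::decrypt_async(PayloadKey([7; 32].into()), &ciphertext[..]);
    /// let mut plaintext = vec![];
    /// futures::io::copy(reader, &mut plaintext).await?;
    /// ```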
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    pub(crate) fn decrypt_async<R: AsyncRead>(key: PayloadKey, inner: R) -> StreamReader<R> {
        StreamReader {
            stream: Self::new(key),
            inner,
            encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
            encrypted_pos: 0,
            start: StartPos::Implicit(0),
            plaintext_len: None,
            cur_plaintext_pos: 0,
            chunk: None,
        }
    }

    fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<Vec<u8>> {
        assert!(chunk.len() <= CHUNK_SIZE);

        self.nonce.set_last(last).map_err(|_| {
            io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
        })?;

        let encrypted = self
            .aead
            .encrypt(&self.nonce.to_bytes().into(), chunk)
            .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size");
        self.nonce.increment_counter();

        Ok(encrypted)
    }

    fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<SecretSlice<u8>> {
        assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE);

        self.nonce.set_last(last).map_err(|_| {
            io::Error::new(io::ErrorKind::InvalidData, "last chunk has been processed")
        })?;

        let decrypted = self
            .aead
            .decrypt(&self.nonce.to_bytes().into(), chunk)
            .map(SecretSlice::from)
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?;
        self.nonce.increment_counter();

        Ok(decrypted)
    }

    fn is_complete(&self) -> bool {
        self.nonce.is_last()
    }
}

/// Writes an encrypted age file.
#[pin_project(project = StreamWriterProj)]
pub struct StreamWriter<W> {
    stream: Stream,
    #[pin]
    inner: W,
    chunk: Vec<u8>,
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    encrypted_chunk: Option<EncryptedChunk>,
}

impl<W: Write> StreamWriter<W> {
    /// Writes the final chunk of the age file.
    ///
    /// You **MUST** call `finish` when you are done writing, in order to finish the
    /// encryption process. Failing to call `finish` will result in a truncated file
    /// that will fail to decrypt.
    pub fn finish(mut self) -> io::Result<W> {
        let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?;
        self.inner.write_all(&encrypted)?;
        Ok(self.inner)
    }
}

impl<W: Write> Write for StreamWriter<W> {
    fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
        let mut bytes_written = 0;

        while !buf.is_empty() {
            let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());
            self.chunk.extend_from_slice(&buf[..to_write]);
            bytes_written += to_write;
            buf = &buf[to_write..];

            // At this point, either buf is empty, or we have a full chunk.
            assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);

            // Only encrypt the chunk if we have more data to write, as the last
            // chunk must be written in finish().
            if !buf.is_empty() {
                let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?;
                self.inner.write_all(&encrypted)?;
                self.chunk.clear();
            }
        }

        Ok(bytes_written)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> StreamWriter<W> {
    fn poll_flush_chunk(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let StreamWriterProj {
            mut inner,
            encrypted_chunk,
            ..
        } = self.project();

        if let Some(chunk) = encrypted_chunk {
            loop {
                chunk.offset +=
                    ready!(inner.as_mut().poll_write(cx, &chunk.bytes[chunk.offset..]))?;
                if chunk.offset == chunk.bytes.len() {
                    break;
                }
            }
        }
        *encrypted_chunk = None;

        Poll::Ready(Ok(()))
    }
}

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        mut buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // If the buffer is empty, return immediately
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        loop {
            ready!(self.as_mut().poll_flush_chunk(cx))?;

            // We can encounter one of three cases here:
            // 1. `0 < buf.len() <= CHUNK_SIZE - self.chunk.len()`: we append to the
            //    partial chunk and return. This may exactly fill the chunk.
            // 2. `0 < CHUNK_SIZE - self.chunk.len() < buf.len()`: we consume part of
            //    `buf` to complete the chunk, encrypt it, and return.
            // 3. `0 == CHUNK_SIZE - self.chunk.len() < buf.len()`: we hit case 1 in a
            //    previous invocation. We encrypt the chunk, and then loop around (where
            //    we are guaranteed to hit case 1 or 2).
            let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());

            self.as_mut()
                .project()
                .chunk
                .extend_from_slice(&buf[..to_write]);
            buf = &buf[to_write..];

            // At this point, either buf is empty, or we have a full chunk.
            assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);

            // Only encrypt the chunk if we have more data to write, as the last
            // chunk must be written in poll_close().
            if !buf.is_empty() {
                let this = self.as_mut().project();
                *this.encrypted_chunk = Some(EncryptedChunk {
                    bytes: this.stream.encrypt_chunk(this.chunk, false)?,
                    offset: 0,
                });
                this.chunk.clear();
            }

            // If we wrote some data, return how much we wrote
            if to_write > 0 {
                return Poll::Ready(Ok(to_write));
            }

            // If we didn't write any data, loop and write some, to ensure
            // this function does not return 0. This enables compatibility with
            // futures::io::copy() and tokio::io::copy(), which will return a
            // WriteZero error in that case.
            // Since those functions copy 8K at a time, and CHUNK_SIZE is
            // a multiple of 8K, this ends up happening once for each chunk
            // after the first one
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Flush any remaining encrypted chunk bytes.
        ready!(self.as_mut().poll_flush_chunk(cx))?;

        if !self.stream.is_complete() {
            // Finish the stream.
            let this = self.as_mut().project();
            *this.encrypted_chunk = Some(EncryptedChunk {
                bytes: this.stream.encrypt_chunk(this.chunk, true)?,
                offset: 0,
            });
        }

        // Flush the final chunk (if we didn't in the first call).
        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_close(cx)
    }
}

/// The position in the underlying reader corresponding to the start of the stream.
///
/// To impl Seek for StreamReader, we need to know the point in the reader corresponding
/// to the first byte of the stream. But we can't query the reader for its current
/// position without having a specific constructor for `R: Read + Seek`, which makes the
/// higher-level API more complex. Instead, we count the number of bytes that have been
/// read from the reader until we first need to seek, and then inside `impl Seek` we can
/// query the reader's current position and figure out where the start was.
enum StartPos {
    /// An offset that we can subtract from the current position.
    Implicit(u64),
    /// The precise start position.
    Explicit(u64),
}

/// Provides access to a decrypted age file.
#[pin_project]
pub struct StreamReader<R> {
    stream: Stream,
    #[pin]
    inner: R,
    encrypted_chunk: Vec<u8>,
    encrypted_pos: usize,
    start: StartPos,
    plaintext_len: Option<u64>,
    cur_plaintext_pos: u64,
    chunk: Option<SecretSlice<u8>>,
}

impl<R> StreamReader<R> {
    fn count_bytes(&mut self, read: usize) {
        // We only need to count if we haven't yet worked out the start position.
        if let StartPos::Implicit(offset) = &mut self.start {
            *offset += read as u64;
        }
    }

    fn decrypt_chunk(&mut self) -> io::Result<()> {
        self.count_bytes(self.encrypted_pos);
        let chunk = &self.encrypted_chunk[..self.encrypted_pos];

        if chunk.is_empty() {
            if !self.stream.is_complete() {
                // Stream has ended before seeing the last chunk.
                return Err(io::Error::new(
                    io::ErrorKind::UnexpectedEof,
                    "age file is truncated",
                ));
            }
        } else {
            // This check works for all cases except when the age file is an integer
            // multiple of the chunk size. In that case, we try decrypting twice on a
            // decryption failure.
            let last = chunk.len() < ENCRYPTED_CHUNK_SIZE;

            self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) {
                (Ok(chunk), _)
                    if chunk.expose_secret().is_empty() && self.cur_plaintext_pos > 0 =>
                {
                    assert!(last);
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        crate::fl!("err-stream-last-chunk-empty"),
                    ));
                }
                (Ok(chunk), _) => Some(chunk),
                (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?),
                (Err(e), true) => return Err(e),
            };
        }

        // We've finished with this encrypted chunk.
        self.encrypted_pos = 0;

        Ok(())
    }

    fn read_from_chunk(&mut self, buf: &mut [u8]) -> usize {
        if self.chunk.is_none() {
            return 0;
        }

        let chunk = self.chunk.as_ref().unwrap();
        let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE;

        let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len());

        buf[..to_read]
            .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]);
        self.cur_plaintext_pos += to_read as u64;
        if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 {
            // We've finished with the current chunk.
            self.chunk = None;
        }

        to_read
    }
}

impl<R: Read> Read for StreamReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.chunk.is_none() {
            while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
                match self
                    .inner
                    .read(&mut self.encrypted_chunk[self.encrypted_pos..])
                {
                    Ok(0) => break,
                    Ok(n) => self.encrypted_pos += n,
                    Err(e) => match e.kind() {
                        io::ErrorKind::Interrupted => (),
                        _ => return Err(e),
                    },
                }
            }
            self.decrypt_chunk()?;
        }

        Ok(self.read_from_chunk(buf))
    }
}

#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncRead + Unpin> AsyncRead for StreamReader<R> {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        if self.chunk.is_none() {
            while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
                let this = self.as_mut().project();
                match ready!(this
                    .inner
                    .poll_read(cx, &mut this.encrypted_chunk[*this.encrypted_pos..]))
                {
                    Ok(0) => break,
                    Ok(n) => self.encrypted_pos += n,
                    Err(e) => match e.kind() {
                        io::ErrorKind::Interrupted => (),
                        _ => return Poll::Ready(Err(e)),
                    },
                }
            }
            self.decrypt_chunk()?;
        }

        Poll::Ready(Ok(self.read_from_chunk(buf)))
    }
}

impl<R: Read + Seek> StreamReader<R> {
    fn start(&mut self) -> io::Result<u64> {
        match self.start {
            StartPos::Implicit(offset) => {
                let current = self.inner.stream_position()?;
                let start = current - offset;

                // Cache the start for future calls.
                self.start = StartPos::Explicit(start);

                Ok(start)
            }
            StartPos::Explicit(start) => Ok(start),
        }
    }

    /// Returns the length of the plaintext.
    fn len(&mut self) -> io::Result<u64> {
        match self.plaintext_len {
            None => {
                // Cache the current position and nonce, and then grab the start and end
                // ciphertext positions.
                let cur_pos = self.inner.stream_position()?;
                let cur_nonce = self.stream.nonce.0;
                let ct_start = self.start()?;
                let ct_end = self.inner.seek(SeekFrom::End(0))?;
                let ct_len = ct_end - ct_start;

                // Use ceiling division to determine the number of chunks.
                let num_chunks =
                    (ct_len + (ENCRYPTED_CHUNK_SIZE as u64 - 1)) / ENCRYPTED_CHUNK_SIZE as u64;

                // Authenticate the ciphertext length by checking that we can successfully
                // decrypt the last chunk _as_ a last chunk.
                let last_chunk_start = ct_start + ((num_chunks - 1) * ENCRYPTED_CHUNK_SIZE as u64);
                let mut last_chunk = Vec::with_capacity((ct_end - last_chunk_start) as usize);
                self.inner.seek(SeekFrom::Start(last_chunk_start))?;
                self.inner.read_to_end(&mut last_chunk)?;
                self.stream.nonce.set_counter(num_chunks - 1);
                self.stream.decrypt_chunk(&last_chunk, true).map_err(|_| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "Last chunk is invalid, stream might be truncated",
                    )
                })?;

                // Now that we have authenticated the ciphertext length, we can use it to
                // calculate the plaintext length.
                let total_tag_size = num_chunks * TAG_SIZE as u64;
                let pt_len = ct_len - total_tag_size;

                // Return to the original position and restore the nonce.
                self.inner.seek(SeekFrom::Start(cur_pos))?;
                self.stream.nonce = Nonce(cur_nonce);

                // Cache the length for future calls.
                self.plaintext_len = Some(pt_len);

                Ok(pt_len)
            }
            Some(pt_len) => Ok(pt_len),
        }
    }
}

impl<R: Read + Seek> Seek for StreamReader<R> {
    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
        // Convert the offset into the target position within the plaintext
        let start = self.start()?;
        let target_pos = match pos {
            SeekFrom::Start(offset) => offset,
            SeekFrom::Current(offset) => {
                let res = (self.cur_plaintext_pos as i64) + offset;
                if res >= 0 {
                    res as u64
                } else {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "cannot seek before the start",
                    ));
                }
            }
            SeekFrom::End(offset) => {
                let res = (self.len()? as i64) + offset;
                if res >= 0 {
                    res as u64
                } else {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "cannot seek before the start",
                    ));
                }
            }
        };

        let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64;

        let target_chunk_index = target_pos / CHUNK_SIZE as u64;
        let target_chunk_offset = target_pos % CHUNK_SIZE as u64;

        if target_chunk_index == cur_chunk_index {
            // We just need to reposition ourselves within the current chunk.
            self.cur_plaintext_pos = target_pos;
        } else {
            // Clear the current chunk
            self.chunk = None;

            // Seek to the beginning of the target chunk
            self.inner.seek(SeekFrom::Start(
                start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64),
            ))?;
            self.stream.nonce.set_counter(target_chunk_index);
            self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64;

            // Read and drop bytes from the chunk to reach the target position.
            if target_chunk_offset > 0 {
                let mut to_drop = vec![0; target_chunk_offset as usize];
                self.read_exact(&mut to_drop)?;
            }
            // We need to handle the edge case where the last chunk is not short, and
            // `target_pos == self.len()` (so `target_chunk_index` points to the chunk
            // after the last chunk). The next read would return no bytes, but because
            // `target_chunk_offset == 0` we weren't forced to read past any in-chunk
            // bytes, and thus have not set the `last` flag on the nonce.
            //
            // To handle this edge case, when `target_pos` is a multiple of the chunk
            // size (i.e. this conditional branch), we compute the length of the
            // plaintext. This is cached, so the overhead should be minimal.
            else if target_pos == self.len()? {
                self.stream
                    .nonce
                    .set_last(true)
                    .expect("We unset the last chunk flag earlier");
            }
        }

        // All done!
        Ok(target_pos)
    }
}

#[cfg(test)]
mod tests {
    use anubis_core::secrecy::ExposeSecret;
    use std::io::{self, Cursor, Read, Seek, SeekFrom, Write};

    use super::{PayloadKey, Stream, CHUNK_SIZE};

    #[cfg(feature = "async")]
    use futures::{
        io::{AsyncRead, AsyncWrite},
        pin_mut,
        task::Poll,
    };
    #[cfg(feature = "async")]
    use futures_test::task::noop_context;

    #[test]
    fn chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let encrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            s.encrypt_chunk(&data, false).unwrap()
        };

        let decrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            s.decrypt_chunk(&encrypted, false).unwrap()
        };

        assert_eq!(decrypted.expose_secret(), &data);
    }

    #[test]
    fn last_chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let encrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            let res = s.encrypt_chunk(&data, true).unwrap();

            // Further calls return an error
            assert_eq!(
                s.encrypt_chunk(&data, false).unwrap_err().kind(),
                io::ErrorKind::WriteZero
            );
            assert_eq!(
                s.encrypt_chunk(&data, true).unwrap_err().kind(),
                io::ErrorKind::WriteZero
            );

            res
        };

        let decrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            let res = s.decrypt_chunk(&encrypted, true).unwrap();

            // Further calls return an error
            match s.decrypt_chunk(&encrypted, false) {
                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
                _ => panic!("Expected error"),
            }
            match s.decrypt_chunk(&encrypted, true) {
                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
                _ => panic!("Expected error"),
            }

            res
        };

        assert_eq!(decrypted.expose_secret(), &data);
    }

    fn stream_round_trip(data: &[u8]) {
        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(data).unwrap();
            w.finish().unwrap();
        };

        let decrypted = {
            let mut buf = vec![];
            let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
            r.read_to_end(&mut buf).unwrap();
            buf
        };

        assert_eq!(decrypted, data);
    }

    /// Check that we can encrypt an empty slice.
    ///
    /// This is the sole exception to the "last chunk must be non-empty" rule.
    #[test]
    fn stream_round_trip_empty() {
        stream_round_trip(&[]);
    }

    #[test]
    fn stream_round_trip_short() {
        stream_round_trip(&[42; 1024]);
    }

    #[test]
    fn stream_round_trip_chunk() {
        stream_round_trip(&[42; CHUNK_SIZE]);
    }

    #[test]
    fn stream_round_trip_long() {
        stream_round_trip(&[42; 100 * 1024]);
    }

    #[cfg(feature = "async")]
    fn stream_async_round_trip(data: &[u8]) {
        let mut encrypted = vec![];
        {
            let w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
            pin_mut!(w);

            let mut cx = noop_context();

            let mut tmp = data;
            loop {
                match w.as_mut().poll_write(&mut cx, tmp) {
                    Poll::Ready(Ok(0)) => break,
                    Poll::Ready(Ok(written)) => tmp = &tmp[written..],
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
            loop {
                match w.as_mut().poll_close(&mut cx) {
                    Poll::Ready(Ok(())) => break,
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        let decrypted = {
            let mut buf = vec![];
            let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
            pin_mut!(r);

            let mut cx = noop_context();

            let mut tmp = [0; 4096];
            loop {
                match r.as_mut().poll_read(&mut cx, &mut tmp) {
                    Poll::Ready(Ok(0)) => break buf,
                    Poll::Ready(Ok(read)) => buf.extend_from_slice(&tmp[..read]),
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_short() {
        stream_async_round_trip(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_chunk() {
        stream_async_round_trip(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_long() {
        stream_async_round_trip(&[42; 100 * 1024]);
    }

    #[cfg(feature = "async")]
    fn stream_async_io_copy(data: &[u8]) {
        use futures::AsyncWriteExt;

        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let mut encrypted = vec![];
        let result = runtime.block_on(async {
            let mut w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
            match futures::io::copy(data, &mut w).await {
                Ok(written) => {
                    w.close().await.unwrap();
                    Ok(written)
                }
                Err(e) => Err(e),
            }
        });

        match result {
            Ok(written) => assert_eq!(written, data.len() as u64),
            Err(e) => panic!("Unexpected error: {}", e),
        }

        let decrypted = {
            let mut buf = vec![];
            let result = runtime.block_on(async {
                let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
                futures::io::copy(r, &mut buf).await
            });

            match result {
                Ok(written) => assert_eq!(written, data.len() as u64),
                Err(e) => panic!("Unexpected error: {}", e),
            }

            buf
        };

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_short() {
        stream_async_io_copy(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_chunk() {
        stream_async_io_copy(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_long() {
        stream_async_io_copy(&[42; 100 * 1024]);
    }

    #[test]
    fn stream_fails_to_decrypt_truncated_file() {
        let data = vec![42; 2 * CHUNK_SIZE];

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&data).unwrap();
            // Forget to call w.finish()!
        };

        let mut buf = vec![];
        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
        assert_eq!(
            r.read_to_end(&mut buf).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn stream_seeking() {
        let mut data = vec![0; 100 * 1024];
        for (i, b) in data.iter_mut().enumerate() {
            *b = i as u8;
        }

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&data).unwrap();
            w.finish().unwrap();
        };

        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(encrypted));

        // Read through into the second chunk
        let mut buf = vec![0; 100];
        for i in 0..700 {
            r.read_exact(&mut buf).unwrap();
            assert_eq!(&buf[..], &data[100 * i..100 * (i + 1)]);
        }

        // Seek back into the first chunk
        r.seek(SeekFrom::Start(250)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[250..350]);

        // Seek forwards within this chunk
        r.seek(SeekFrom::Current(510)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[860..960]);

        // Seek backwards from the end
        r.seek(SeekFrom::End(-1337)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[data.len() - 1337..data.len() - 1237]);
    }

    #[test]
    fn seek_from_end_fails_on_truncation() {
        // The plaintext is the string "hello" followed by 65536 zeros, just enough to
        // give us some bytes to play with in the second chunk.
        let mut plaintext: Vec<u8> = b"hello".to_vec();
        plaintext.extend_from_slice(&[0; 65536]);

        // Encrypt the plaintext just like the example code in the docs.
        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&plaintext).unwrap();
            w.finish().unwrap();
        };

        // First check the correct behavior of seeks relative to EOF. Create a decrypting
        // reader, and move it one byte forward from the start, using SeekFrom::End.
        // Confirm that reading 4 bytes from that point gives us "ello", as it should.
        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
        let eof_relative_offset = 1_i64 - plaintext.len() as i64;
        reader.seek(SeekFrom::End(eof_relative_offset)).unwrap();
        let mut buf = [0; 4];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ello", "This is correct.");

        // Do the same thing again, except this time truncate the ciphertext by one byte
        // first. This should cause some sort of error, instead of a successful read that
        // returns the wrong plaintext.
        let truncated_ciphertext = &encrypted[..encrypted.len() - 1];
        let mut truncated_reader = Stream::decrypt(
            PayloadKey([7; 32].into()),
            Cursor::new(truncated_ciphertext),
        );
        // Use the same seek target as above.
        match truncated_reader.seek(SeekFrom::End(eof_relative_offset)) {
            Err(e) => {
                assert_eq!(e.kind(), io::ErrorKind::InvalidData);
                assert_eq!(
                    &e.to_string(),
                    "Last chunk is invalid, stream might be truncated",
                );
            }
            Ok(_) => panic!("This is a security issue."),
        }
    }

    #[test]
    fn seek_from_end_with_exact_chunk() {
        let plaintext: Vec<u8> = vec![42; 65536];

        // Encrypt the plaintext just like the example code in the docs.
        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&plaintext).unwrap();
            w.finish().unwrap();
        };

        // Seek to the end of the plaintext before decrypting.
        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
        reader.seek(SeekFrom::End(0)).unwrap();

        // Reading should return no bytes, because we're already at EOF.
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();
        assert_eq!(buf.len(), 0);
    }
}