anubis_age/primitives/
stream.rs

1//! I/O helper structs for age file encryption and decryption.
2
3use aes_gcm_siv::{
4    aead::{generic_array::GenericArray, Aead, KeyInit, KeySizeUser},
5    Aes256GcmSiv,
6};
7use anubis_core::secrecy::{ExposeSecret, SecretSlice};
8use pin_project::pin_project;
9use std::cmp;
10use std::io::{self, Read, Seek, SeekFrom, Write};
11use zeroize::Zeroize;
12
13#[cfg(feature = "async")]
14use futures::{
15    io::{AsyncRead, AsyncWrite, Error},
16    ready,
17    task::{Context, Poll},
18};
19#[cfg(feature = "async")]
20use std::pin::Pin;
21
/// Size in bytes of the AEAD authentication tag appended to every chunk.
const TAG_SIZE: usize = 16;
/// Maximum number of plaintext bytes carried by one STREAM chunk (64 KiB).
const CHUNK_SIZE: usize = 64 * 1024;
/// On-the-wire size of a full chunk: its plaintext plus the tag.
const ENCRYPTED_CHUNK_SIZE: usize = TAG_SIZE + CHUNK_SIZE;
25
26pub(crate) struct PayloadKey(pub(crate) GenericArray<u8, <Aes256GcmSiv as KeySizeUser>::KeySize>);
27
28impl Drop for PayloadKey {
29    fn drop(&mut self) {
30        self.0.as_mut_slice().zeroize();
31    }
32}
33
/// The nonce used in age's STREAM encryption.
///
/// Structured as 11 bytes of big-endian counter followed by 1 byte of last-chunk
/// flag (`0x00` / `0x01`). We store this in the lower 12 bytes of a `u128`: bit 0
/// holds the flag, bits 8..96 hold the counter, and bits 96..128 remain zero.
#[derive(Clone, Copy, Default)]
struct Nonce(u128);

impl Nonce {
    /// Sets the chunk counter to `val` and clears the last-chunk flag.
    fn set_counter(&mut self, val: u64) {
        self.0 = u128::from(val) << 8;
    }

    /// Advances the 11-byte chunk counter by one, preserving the last-chunk flag.
    ///
    /// Returns `Err(())` and leaves the nonce unchanged if the counter would no
    /// longer fit in 88 bits (only possible after 2^88 chunks).
    fn increment_counter(&mut self) -> Result<(), ()> {
        // Cannot overflow the u128: the stored value is always below 2^96.
        let next = self.0 + (1 << 8);
        if next >> (8 * 12) != 0 {
            return Err(());
        }
        self.0 = next;
        Ok(())
    }

    /// Returns whether the last-chunk flag is set.
    fn is_last(&self) -> bool {
        self.0 & 1 != 0
    }

    /// Sets the last-chunk flag when `last` is true (`false` leaves it clear).
    ///
    /// Returns `Err(())` if the flag is already set, i.e. the final chunk has
    /// already been processed.
    fn set_last(&mut self, last: bool) -> Result<(), ()> {
        if self.is_last() {
            return Err(());
        }
        self.0 |= u128::from(last);
        Ok(())
    }

    /// Serializes the nonce into the 12-byte big-endian form passed to the AEAD.
    fn to_bytes(self) -> [u8; 12] {
        // The u128 is 16 bytes big-endian and the nonce is its low 12 bytes. The
        // copy between fixed-size buffers cannot fail, so there is no fallback value
        // (a constant fallback nonce would risk nonce reuse).
        let wide = self.0.to_be_bytes();
        let mut out = [0u8; 12];
        out.copy_from_slice(&wide[4..]);
        out
    }
}
78
/// An encrypted chunk waiting to be written out by the async writer.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
struct EncryptedChunk {
    /// The complete ciphertext of the chunk.
    bytes: Vec<u8>,
    /// How many bytes of `bytes` have already been written.
    offset: usize,
}
85
86/// `STREAM[key](plaintext)`
87///
88/// The [STREAM] construction for online authenticated encryption, instantiated with
89/// AES-256-GCM-SIV (RFC 8452) in 64KiB chunks, and a nonce structure of 11 bytes of big endian
90/// counter, and 1 byte of last block flag (0x00 / 0x01).
91///
92/// [STREAM]: https://eprint.iacr.org/2015/189.pdf
93pub(crate) struct Stream {
94    aead: Aes256GcmSiv,
95    nonce: Nonce,
96}
97
98impl Stream {
99    fn new(key: PayloadKey) -> Self {
100        Stream {
101            aead: Aes256GcmSiv::new(&key.0),
102            nonce: Nonce::default(),
103        }
104    }
105
106    /// Wraps `STREAM` encryption under the given `key` around a writer.
107    ///
108    /// `key` must **never** be repeated across multiple streams. In `age` this is
109    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
110    /// random nonce.
111    ///
112    /// [`HKDF`]: anubis_core::primitives::hkdf
113    pub(crate) fn encrypt<W: Write>(key: PayloadKey, inner: W) -> StreamWriter<W> {
114        StreamWriter {
115            stream: Self::new(key),
116            inner,
117            chunk: Vec::with_capacity(CHUNK_SIZE),
118            #[cfg(feature = "async")]
119            encrypted_chunk: None,
120        }
121    }
122
123    /// Wraps `STREAM` encryption under the given `key` around a writer.
124    ///
125    /// `key` must **never** be repeated across multiple streams. In `age` this is
126    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
127    /// random nonce.
128    ///
129    /// [`HKDF`]: anubis_core::primitives::hkdf
130    #[cfg(feature = "async")]
131    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
132    pub(crate) fn encrypt_async<W: AsyncWrite>(key: PayloadKey, inner: W) -> StreamWriter<W> {
133        StreamWriter {
134            stream: Self::new(key),
135            inner,
136            chunk: Vec::with_capacity(CHUNK_SIZE),
137            encrypted_chunk: None,
138        }
139    }
140
141    /// Wraps `STREAM` decryption under the given `key` around a reader.
142    ///
143    /// `key` must **never** be repeated across multiple streams. In `age` this is
144    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
145    /// random nonce.
146    ///
147    /// [`HKDF`]: anubis_core::primitives::hkdf
148    pub(crate) fn decrypt<R: Read>(key: PayloadKey, inner: R) -> StreamReader<R> {
149        StreamReader {
150            stream: Self::new(key),
151            inner,
152            encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
153            encrypted_pos: 0,
154            start: StartPos::Implicit(0),
155            plaintext_len: None,
156            cur_plaintext_pos: 0,
157            chunk: None,
158        }
159    }
160
161    /// Wraps `STREAM` decryption under the given `key` around a reader.
162    ///
163    /// `key` must **never** be repeated across multiple streams. In `age` this is
164    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
165    /// random nonce.
166    ///
167    /// [`HKDF`]: anubis_core::primitives::hkdf
168    #[cfg(feature = "async")]
169    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
170    pub(crate) fn decrypt_async<R: AsyncRead>(key: PayloadKey, inner: R) -> StreamReader<R> {
171        StreamReader {
172            stream: Self::new(key),
173            inner,
174            encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
175            encrypted_pos: 0,
176            start: StartPos::Implicit(0),
177            plaintext_len: None,
178            cur_plaintext_pos: 0,
179            chunk: None,
180        }
181    }
182
183    fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<Vec<u8>> {
184        assert!(chunk.len() <= CHUNK_SIZE);
185
186        self.nonce.set_last(last).map_err(|_| {
187            io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
188        })?;
189
190        let encrypted = self
191            .aead
192            .encrypt(&self.nonce.to_bytes().into(), chunk)
193            .map_err(|_| {
194                io::Error::new(
195                    io::ErrorKind::Other,
196                    "AES-256-GCM-SIV encryption failed unexpectedly",
197                )
198            })?;
199
200        self.nonce.increment_counter().map_err(|_| {
201            io::Error::new(
202                io::ErrorKind::WriteZero,
203                "nonce counter overflow (encrypted more than 2^88 chunks)",
204            )
205        })?;
206
207        Ok(encrypted)
208    }
209
210    fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<SecretSlice<u8>> {
211        assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE);
212
213        self.nonce.set_last(last).map_err(|_| {
214            io::Error::new(io::ErrorKind::InvalidData, "last chunk has been processed")
215        })?;
216
217        let decrypted = self
218            .aead
219            .decrypt(&self.nonce.to_bytes().into(), chunk)
220            .map(SecretSlice::from)
221            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?;
222
223        self.nonce.increment_counter().map_err(|_| {
224            io::Error::new(
225                io::ErrorKind::InvalidData,
226                "nonce counter overflow (decrypted more than 2^88 chunks)",
227            )
228        })?;
229
230        Ok(decrypted)
231    }
232
233    fn is_complete(&self) -> bool {
234        self.nonce.is_last()
235    }
236}
237
238/// Writes an encrypted age file.
239#[pin_project(project = StreamWriterProj)]
240pub struct StreamWriter<W> {
241    stream: Stream,
242    #[pin]
243    inner: W,
244    chunk: Vec<u8>,
245    #[cfg(feature = "async")]
246    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
247    encrypted_chunk: Option<EncryptedChunk>,
248}
249
250impl<W: Write> StreamWriter<W> {
251    /// Writes the final chunk of the age file.
252    ///
253    /// You **MUST** call `finish` when you are done writing, in order to finish the
254    /// encryption process. Failing to call `finish` will result in a truncated file that
255    /// that will fail to decrypt.
256    pub fn finish(mut self) -> io::Result<W> {
257        let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?;
258        self.inner.write_all(&encrypted)?;
259        Ok(self.inner)
260    }
261}
262
263impl<W: Write> Write for StreamWriter<W> {
264    fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
265        let mut bytes_written = 0;
266
267        while !buf.is_empty() {
268            let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());
269            self.chunk.extend_from_slice(&buf[..to_write]);
270            bytes_written += to_write;
271            buf = &buf[to_write..];
272
273            // At this point, either buf is empty, or we have a full chunk.
274            assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);
275
276            // Only encrypt the chunk if we have more data to write, as the last
277            // chunk must be written in finish().
278            if !buf.is_empty() {
279                let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?;
280                self.inner.write_all(&encrypted)?;
281                self.chunk.clear();
282            }
283        }
284
285        Ok(bytes_written)
286    }
287
288    fn flush(&mut self) -> io::Result<()> {
289        self.inner.flush()
290    }
291}
292
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> StreamWriter<W> {
    /// Writes any pending encrypted chunk to `inner`, resuming from where an earlier
    /// call left off.
    ///
    /// Returns `Poll::Pending` while `inner` is not ready; progress is kept in
    /// `encrypted_chunk.offset`. Resolves to `Ok(())` once nothing is left to write,
    /// and the pending chunk is then cleared.
    ///
    /// # Errors
    ///
    /// Propagates errors from `inner`. Returns `io::ErrorKind::WriteZero` if `inner`
    /// accepts zero bytes of a non-empty buffer; without this check the loop could
    /// never make progress and would spin forever.
    fn poll_flush_chunk(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let StreamWriterProj {
            mut inner,
            encrypted_chunk,
            ..
        } = self.project();

        if let Some(chunk) = encrypted_chunk {
            while chunk.offset < chunk.bytes.len() {
                let written =
                    ready!(inner.as_mut().poll_write(cx, &chunk.bytes[chunk.offset..]))?;
                if written == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write whole encrypted chunk",
                    )));
                }
                chunk.offset += written;
            }
        }
        *encrypted_chunk = None;

        Poll::Ready(Ok(()))
    }
}
317
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        mut buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // Nothing to accept means nothing to flush or buffer either.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        loop {
            // A previously encrypted chunk must reach `inner` before we buffer more.
            ready!(self.as_mut().poll_flush_chunk(cx))?;

            let this = self.as_mut().project();

            // Three situations are possible here:
            // 1. `0 < buf.len() <= CHUNK_SIZE - chunk.len()`: everything fits in the
            //    partial chunk and we return (this may happen to fill the chunk).
            // 2. `0 < CHUNK_SIZE - chunk.len() < buf.len()`: part of `buf` completes
            //    the chunk, which we encrypt before returning.
            // 3. `0 == CHUNK_SIZE - chunk.len() < buf.len()`: case 1 filled the chunk
            //    on an earlier call. We encrypt it and go around again, where case 1
            //    is then guaranteed.
            let accepted = cmp::min(CHUNK_SIZE - this.chunk.len(), buf.len());
            let (head, rest) = buf.split_at(accepted);
            this.chunk.extend_from_slice(head);
            buf = rest;

            // At this point, either buf is empty, or we have a full chunk.
            assert!(buf.is_empty() || this.chunk.len() == CHUNK_SIZE);

            if !buf.is_empty() {
                // More data follows, so this full chunk is not the last one; the
                // final chunk is only ever produced in poll_close().
                let bytes = this.stream.encrypt_chunk(this.chunk, false)?;
                *this.encrypted_chunk = Some(EncryptedChunk { bytes, offset: 0 });
                this.chunk.clear();
            }

            if accepted > 0 {
                return Poll::Ready(Ok(accepted));
            }

            // Nothing was accepted this time (case 3), so loop instead of returning 0.
            // futures::io::copy() and tokio::io::copy() treat a 0-byte write as a
            // WriteZero error. They copy 8K at a time and CHUNK_SIZE is a multiple of
            // 8K, so this happens once per chunk after the first one.
        }
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        match self.as_mut().poll_flush_chunk(cx) {
            Poll::Ready(Ok(())) => self.project().inner.poll_flush(cx),
            not_done => not_done,
        }
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Drain any full chunk still waiting to be written.
        ready!(self.as_mut().poll_flush_chunk(cx))?;

        let this = self.as_mut().project();
        if !this.stream.is_complete() {
            // Seal whatever is buffered as the final chunk (only on the first call).
            let bytes = this.stream.encrypt_chunk(this.chunk, true)?;
            *this.encrypted_chunk = Some(EncryptedChunk { bytes, offset: 0 });
        }

        // Write out the final chunk (a no-op if an earlier call already did).
        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_close(cx)
    }
}
402
/// The position in the underlying reader corresponding to the start of the stream.
///
/// Implementing `Seek` for `StreamReader` requires knowing where the stream's first
/// byte sits in the reader. Querying the reader's position up front would need a
/// dedicated constructor for `R: Read + Seek`, which complicates the higher-level API.
/// Instead, we count the bytes read from the reader until the first seek. At that
/// point we ask the reader for its current position and subtract the count.
enum StartPos {
    /// Not resolved yet: the number of bytes read so far, to be subtracted from the
    /// reader's current position.
    Implicit(u64),
    /// The resolved, precise start position.
    Explicit(u64),
}
417
418/// Provides access to a decrypted age file.
419#[pin_project]
420pub struct StreamReader<R> {
421    stream: Stream,
422    #[pin]
423    inner: R,
424    encrypted_chunk: Vec<u8>,
425    encrypted_pos: usize,
426    start: StartPos,
427    plaintext_len: Option<u64>,
428    cur_plaintext_pos: u64,
429    chunk: Option<SecretSlice<u8>>,
430}
431
432impl<R> StreamReader<R> {
433    fn count_bytes(&mut self, read: usize) {
434        // We only need to count if we haven't yet worked out the start position.
435        if let StartPos::Implicit(offset) = &mut self.start {
436            *offset += read as u64;
437        }
438    }
439
440    fn decrypt_chunk(&mut self) -> io::Result<()> {
441        self.count_bytes(self.encrypted_pos);
442        let chunk = &self.encrypted_chunk[..self.encrypted_pos];
443
444        if chunk.is_empty() {
445            if !self.stream.is_complete() {
446                // Stream has ended before seeing the last chunk.
447                return Err(io::Error::new(
448                    io::ErrorKind::UnexpectedEof,
449                    "age file is truncated",
450                ));
451            }
452        } else {
453            // This check works for all cases except when the age file is an integer
454            // multiple of the chunk size. In that case, we try decrypting twice on a
455            // decryption failure.
456            let last = chunk.len() < ENCRYPTED_CHUNK_SIZE;
457
458            self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) {
459                (Ok(chunk), _)
460                    if chunk.expose_secret().is_empty() && self.cur_plaintext_pos > 0 =>
461                {
462                    assert!(last);
463                    return Err(io::Error::new(
464                        io::ErrorKind::InvalidData,
465                        crate::fl!("err-stream-last-chunk-empty"),
466                    ));
467                }
468                (Ok(chunk), _) => Some(chunk),
469                (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?),
470                (Err(e), true) => return Err(e),
471            };
472        }
473
474        // We've finished with this encrypted chunk.
475        self.encrypted_pos = 0;
476
477        Ok(())
478    }
479
480    fn read_from_chunk(&mut self, buf: &mut [u8]) -> usize {
481        if self.chunk.is_none() {
482            return 0;
483        }
484
485        let chunk = self.chunk.as_ref().unwrap();
486        let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE;
487
488        let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len());
489
490        buf[..to_read]
491            .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]);
492        self.cur_plaintext_pos += to_read as u64;
493        if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 {
494            // We've finished with the current chunk.
495            self.chunk = None;
496        }
497
498        to_read
499    }
500}
501
502impl<R: Read> Read for StreamReader<R> {
503    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
504        if self.chunk.is_none() {
505            while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
506                match self
507                    .inner
508                    .read(&mut self.encrypted_chunk[self.encrypted_pos..])
509                {
510                    Ok(0) => break,
511                    Ok(n) => self.encrypted_pos += n,
512                    Err(e) => match e.kind() {
513                        io::ErrorKind::Interrupted => (),
514                        _ => return Err(e),
515                    },
516                }
517            }
518            self.decrypt_chunk()?;
519        }
520
521        Ok(self.read_from_chunk(buf))
522    }
523}
524
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncRead + Unpin> AsyncRead for StreamReader<R> {
    /// Async counterpart of `Read::read`. Ciphertext buffered before a
    /// `Poll::Pending` is kept in `encrypted_chunk` and reused on the next poll.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        if self.chunk.is_none() {
            while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
                let this = self.as_mut().project();
                let dest = &mut this.encrypted_chunk[*this.encrypted_pos..];
                match ready!(this.inner.poll_read(cx, dest)) {
                    Ok(0) => break,
                    Ok(n) => *this.encrypted_pos += n,
                    Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                    Err(e) => return Poll::Ready(Err(e)),
                }
            }
            self.decrypt_chunk()?;
        }

        Poll::Ready(Ok(self.read_from_chunk(buf)))
    }
}
554
555impl<R: Read + Seek> StreamReader<R> {
556    fn start(&mut self) -> io::Result<u64> {
557        match self.start {
558            StartPos::Implicit(offset) => {
559                let current = self.inner.stream_position()?;
560                let start = current - offset;
561
562                // Cache the start for future calls.
563                self.start = StartPos::Explicit(start);
564
565                Ok(start)
566            }
567            StartPos::Explicit(start) => Ok(start),
568        }
569    }
570
571    /// Returns the length of the plaintext
572    fn len(&mut self) -> io::Result<u64> {
573        match self.plaintext_len {
574            None => {
575                // Cache the current position and nonce, and then grab the start and end
576                // ciphertext positions.
577                let cur_pos = self.inner.stream_position()?;
578                let cur_nonce = self.stream.nonce.0;
579                let ct_start = self.start()?;
580                let ct_end = self.inner.seek(SeekFrom::End(0))?;
581                let ct_len = ct_end - ct_start;
582
583                // Use ceiling division to determine the number of chunks.
584                let num_chunks =
585                    (ct_len + (ENCRYPTED_CHUNK_SIZE as u64 - 1)) / ENCRYPTED_CHUNK_SIZE as u64;
586
587                // Authenticate the ciphertext length by checking that we can successfully
588                // decrypt the last chunk _as_ a last chunk.
589                let last_chunk_start = ct_start + ((num_chunks - 1) * ENCRYPTED_CHUNK_SIZE as u64);
590                let mut last_chunk = Vec::with_capacity((ct_end - last_chunk_start) as usize);
591                self.inner.seek(SeekFrom::Start(last_chunk_start))?;
592                self.inner.read_to_end(&mut last_chunk)?;
593                self.stream.nonce.set_counter(num_chunks - 1);
594                self.stream.decrypt_chunk(&last_chunk, true).map_err(|_| {
595                    io::Error::new(
596                        io::ErrorKind::InvalidData,
597                        "Last chunk is invalid, stream might be truncated",
598                    )
599                })?;
600
601                // Now that we have authenticated the ciphertext length, we can use it to
602                // calculate the plaintext length.
603                let total_tag_size = num_chunks * TAG_SIZE as u64;
604                let pt_len = ct_len - total_tag_size;
605
606                // Return to the original position and restore the nonce.
607                self.inner.seek(SeekFrom::Start(cur_pos))?;
608                self.stream.nonce = Nonce(cur_nonce);
609
610                // Cache the length for future calls.
611                self.plaintext_len = Some(pt_len);
612
613                Ok(pt_len)
614            }
615            Some(pt_len) => Ok(pt_len),
616        }
617    }
618}
619
620impl<R: Read + Seek> Seek for StreamReader<R> {
621    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
622        // Convert the offset into the target position within the plaintext
623        let start = self.start()?;
624        let target_pos = match pos {
625            SeekFrom::Start(offset) => offset,
626            SeekFrom::Current(offset) => {
627                let res = (self.cur_plaintext_pos as i64) + offset;
628                if res >= 0 {
629                    res as u64
630                } else {
631                    return Err(io::Error::new(
632                        io::ErrorKind::InvalidData,
633                        "cannot seek before the start",
634                    ));
635                }
636            }
637            SeekFrom::End(offset) => {
638                let res = (self.len()? as i64) + offset;
639                if res >= 0 {
640                    res as u64
641                } else {
642                    return Err(io::Error::new(
643                        io::ErrorKind::InvalidData,
644                        "cannot seek before the start",
645                    ));
646                }
647            }
648        };
649
650        let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64;
651
652        let target_chunk_index = target_pos / CHUNK_SIZE as u64;
653        let target_chunk_offset = target_pos % CHUNK_SIZE as u64;
654
655        if target_chunk_index == cur_chunk_index {
656            // We just need to reposition ourselves within the current chunk.
657            self.cur_plaintext_pos = target_pos;
658        } else {
659            // Clear the current chunk
660            self.chunk = None;
661
662            // Seek to the beginning of the target chunk
663            self.inner.seek(SeekFrom::Start(
664                start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64),
665            ))?;
666            self.stream.nonce.set_counter(target_chunk_index);
667            self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64;
668
669            // Read and drop bytes from the chunk to reach the target position.
670            if target_chunk_offset > 0 {
671                let mut to_drop = vec![0; target_chunk_offset as usize];
672                self.read_exact(&mut to_drop)?;
673            }
674            // We need to handle the edge case where the last chunk is not short, and
675            // `target_pos == self.len()` (so `target_chunk_index` points to the chunk
676            // after the last chunk). The next read would return no bytes, but because
677            // `target_chunk_offset == 0` we weren't forced to read past any in-chunk
678            // bytes, and thus have not set the `last` flag on the nonce.
679            //
680            // To handle this edge case, when `target_pos` is a multiple of the chunk
681            // size (i.e. this conditional branch), we compute the length of the
682            // plaintext. This is cached, so the overhead should be minimal.
683            else if target_pos == self.len()? {
684                self.stream.nonce.set_last(true).map_err(|_| {
685                    io::Error::new(
686                        io::ErrorKind::InvalidData,
687                        "nonce is already set as last chunk",
688                    )
689                })?;
690            }
691        }
692
693        // All done!
694        Ok(target_pos)
695    }
696}
697
698#[cfg(test)]
699mod tests {
700    use anubis_core::secrecy::ExposeSecret;
701    use std::io::{self, Cursor, Read, Seek, SeekFrom, Write};
702
703    use super::{PayloadKey, Stream, CHUNK_SIZE};
704
705    #[cfg(feature = "async")]
706    use futures::{
707        io::{AsyncRead, AsyncWrite},
708        pin_mut,
709        task::Poll,
710    };
711    #[cfg(feature = "async")]
712    use futures_test::task::noop_context;
713
714    #[test]
715    fn chunk_round_trip() {
716        let data = vec![42; CHUNK_SIZE];
717
718        let encrypted = {
719            let mut s = Stream::new(PayloadKey([7; 32].into()));
720            s.encrypt_chunk(&data, false).unwrap()
721        };
722
723        let decrypted = {
724            let mut s = Stream::new(PayloadKey([7; 32].into()));
725            s.decrypt_chunk(&encrypted, false).unwrap()
726        };
727
728        assert_eq!(decrypted.expose_secret(), &data);
729    }
730
731    #[test]
732    fn last_chunk_round_trip() {
733        let data = vec![42; CHUNK_SIZE];
734
735        let encrypted = {
736            let mut s = Stream::new(PayloadKey([7; 32].into()));
737            let res = s.encrypt_chunk(&data, true).unwrap();
738
739            // Further calls return an error
740            assert_eq!(
741                s.encrypt_chunk(&data, false).unwrap_err().kind(),
742                io::ErrorKind::WriteZero
743            );
744            assert_eq!(
745                s.encrypt_chunk(&data, true).unwrap_err().kind(),
746                io::ErrorKind::WriteZero
747            );
748
749            res
750        };
751
752        let decrypted = {
753            let mut s = Stream::new(PayloadKey([7; 32].into()));
754            let res = s.decrypt_chunk(&encrypted, true).unwrap();
755
756            // Further calls return an error
757            match s.decrypt_chunk(&encrypted, false) {
758                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
759                _ => panic!("Expected error"),
760            }
761            match s.decrypt_chunk(&encrypted, true) {
762                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
763                _ => panic!("Expected error"),
764            }
765
766            res
767        };
768
769        assert_eq!(decrypted.expose_secret(), &data);
770    }
771
772    fn stream_round_trip(data: &[u8]) {
773        let mut encrypted = vec![];
774        {
775            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
776            w.write_all(data).unwrap();
777            w.finish().unwrap();
778        };
779
780        let decrypted = {
781            let mut buf = vec![];
782            let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
783            r.read_to_end(&mut buf).unwrap();
784            buf
785        };
786
787        assert_eq!(decrypted, data);
788    }
789
790    /// Check that we can encrypt an empty slice.
791    ///
792    /// This is the sole exception to the "last chunk must be non-empty" rule.
793    #[test]
794    fn stream_round_trip_empty() {
795        stream_round_trip(&[]);
796    }
797
798    #[test]
799    fn stream_round_trip_short() {
800        stream_round_trip(&[42; 1024]);
801    }
802
803    #[test]
804    fn stream_round_trip_chunk() {
805        stream_round_trip(&[42; CHUNK_SIZE]);
806    }
807
808    #[test]
809    fn stream_round_trip_long() {
810        stream_round_trip(&[42; 100 * 1024]);
811    }
812
813    #[cfg(feature = "async")]
814    fn stream_async_round_trip(data: &[u8]) {
815        let mut encrypted = vec![];
816        {
817            let w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
818            pin_mut!(w);
819
820            let mut cx = noop_context();
821
822            let mut tmp = data;
823            loop {
824                match w.as_mut().poll_write(&mut cx, tmp) {
825                    Poll::Ready(Ok(0)) => break,
826                    Poll::Ready(Ok(written)) => tmp = &tmp[written..],
827                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
828                    Poll::Pending => panic!("Unexpected Pending"),
829                }
830            }
831            loop {
832                match w.as_mut().poll_close(&mut cx) {
833                    Poll::Ready(Ok(())) => break,
834                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
835                    Poll::Pending => panic!("Unexpected Pending"),
836                }
837            }
838        };
839
840        let decrypted = {
841            let mut buf = vec![];
842            let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
843            pin_mut!(r);
844
845            let mut cx = noop_context();
846
847            let mut tmp = [0; 4096];
848            loop {
849                match r.as_mut().poll_read(&mut cx, &mut tmp) {
850                    Poll::Ready(Ok(0)) => break buf,
851                    Poll::Ready(Ok(read)) => buf.extend_from_slice(&tmp[..read]),
852                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
853                    Poll::Pending => panic!("Unexpected Pending"),
854                }
855            }
856        };
857
858        assert_eq!(decrypted, data);
859    }
860
861    #[cfg(feature = "async")]
862    #[test]
863    fn stream_async_round_trip_short() {
864        stream_async_round_trip(&[42; 1024]);
865    }
866
867    #[cfg(feature = "async")]
868    #[test]
869    fn stream_async_round_trip_chunk() {
870        stream_async_round_trip(&[42; CHUNK_SIZE]);
871    }
872
873    #[cfg(feature = "async")]
874    #[test]
875    fn stream_async_round_trip_long() {
876        stream_async_round_trip(&[42; 100 * 1024]);
877    }
878
879    #[cfg(feature = "async")]
880    fn stream_async_io_copy(data: &[u8]) {
881        use futures::AsyncWriteExt;
882
883        let runtime = tokio::runtime::Builder::new_current_thread()
884            .build()
885            .unwrap();
886        let mut encrypted = vec![];
887        let result = runtime.block_on(async {
888            let mut w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
889            match futures::io::copy(data, &mut w).await {
890                Ok(written) => {
891                    w.close().await.unwrap();
892                    Ok(written)
893                }
894                Err(e) => Err(e),
895            }
896        });
897
898        match result {
899            Ok(written) => assert_eq!(written, data.len() as u64),
900            Err(e) => panic!("Unexpected error: {}", e),
901        }
902
903        let decrypted = {
904            let mut buf = vec![];
905            let result = runtime.block_on(async {
906                let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
907                futures::io::copy(r, &mut buf).await
908            });
909
910            match result {
911                Ok(written) => assert_eq!(written, data.len() as u64),
912                Err(e) => panic!("Unexpected error: {}", e),
913            }
914
915            buf
916        };
917
918        assert_eq!(decrypted, data);
919    }
920
921    #[cfg(feature = "async")]
922    #[test]
923    fn stream_async_io_copy_short() {
924        stream_async_io_copy(&[42; 1024]);
925    }
926
927    #[cfg(feature = "async")]
928    #[test]
929    fn stream_async_io_copy_chunk() {
930        stream_async_io_copy(&[42; CHUNK_SIZE]);
931    }
932
933    #[cfg(feature = "async")]
934    #[test]
935    fn stream_async_io_copy_long() {
936        stream_async_io_copy(&[42; 100 * 1024]);
937    }
938
939    #[test]
940    fn stream_fails_to_decrypt_truncated_file() {
941        let data = vec![42; 2 * CHUNK_SIZE];
942
943        let mut encrypted = vec![];
944        {
945            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
946            w.write_all(&data).unwrap();
947            // Forget to call w.finish()!
948        };
949
950        let mut buf = vec![];
951        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
952        assert_eq!(
953            r.read_to_end(&mut buf).unwrap_err().kind(),
954            io::ErrorKind::UnexpectedEof
955        );
956    }
957
958    #[test]
959    fn stream_seeking() {
960        let mut data = vec![0; 100 * 1024];
961        for (i, b) in data.iter_mut().enumerate() {
962            *b = i as u8;
963        }
964
965        let mut encrypted = vec![];
966        {
967            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
968            w.write_all(&data).unwrap();
969            w.finish().unwrap();
970        };
971
972        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(encrypted));
973
974        // Read through into the second chunk
975        let mut buf = vec![0; 100];
976        for i in 0..700 {
977            r.read_exact(&mut buf).unwrap();
978            assert_eq!(&buf[..], &data[100 * i..100 * (i + 1)]);
979        }
980
981        // Seek back into the first chunk
982        r.seek(SeekFrom::Start(250)).unwrap();
983        r.read_exact(&mut buf).unwrap();
984        assert_eq!(&buf[..], &data[250..350]);
985
986        // Seek forwards within this chunk
987        r.seek(SeekFrom::Current(510)).unwrap();
988        r.read_exact(&mut buf).unwrap();
989        assert_eq!(&buf[..], &data[860..960]);
990
991        // Seek backwards from the end
992        r.seek(SeekFrom::End(-1337)).unwrap();
993        r.read_exact(&mut buf).unwrap();
994        assert_eq!(&buf[..], &data[data.len() - 1337..data.len() - 1237]);
995    }
996
997    #[test]
998    fn seek_from_end_fails_on_truncation() {
999        // The plaintext is the string "hello" followed by 65536 zeros, just enough to
1000        // give us some bytes to play with in the second chunk.
1001        let mut plaintext: Vec<u8> = b"hello".to_vec();
1002        plaintext.extend_from_slice(&[0; 65536]);
1003
1004        // Encrypt the plaintext just like the example code in the docs.
1005        let mut encrypted = vec![];
1006        {
1007            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
1008            w.write_all(&plaintext).unwrap();
1009            w.finish().unwrap();
1010        };
1011
1012        // First check the correct behavior of seeks relative to EOF. Create a decrypting
1013        // reader, and move it one byte forward from the start, using SeekFrom::End.
1014        // Confirm that reading 4 bytes from that point gives us "ello", as it should.
1015        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
1016        let eof_relative_offset = 1_i64 - plaintext.len() as i64;
1017        reader.seek(SeekFrom::End(eof_relative_offset)).unwrap();
1018        let mut buf = [0; 4];
1019        reader.read_exact(&mut buf).unwrap();
1020        assert_eq!(&buf, b"ello", "This is correct.");
1021
1022        // Do the same thing again, except this time truncate the ciphertext by one byte
1023        // first. This should cause some sort of error, instead of a successful read that
1024        // returns the wrong plaintext.
1025        let truncated_ciphertext = &encrypted[..encrypted.len() - 1];
1026        let mut truncated_reader = Stream::decrypt(
1027            PayloadKey([7; 32].into()),
1028            Cursor::new(truncated_ciphertext),
1029        );
1030        // Use the same seek target as above.
1031        match truncated_reader.seek(SeekFrom::End(eof_relative_offset)) {
1032            Err(e) => {
1033                assert_eq!(e.kind(), io::ErrorKind::InvalidData);
1034                assert_eq!(
1035                    &e.to_string(),
1036                    "Last chunk is invalid, stream might be truncated",
1037                );
1038            }
1039            Ok(_) => panic!("This is a security issue."),
1040        }
1041    }
1042
1043    #[test]
1044    fn seek_from_end_with_exact_chunk() {
1045        let plaintext: Vec<u8> = vec![42; 65536];
1046
1047        // Encrypt the plaintext just like the example code in the docs.
1048        let mut encrypted = vec![];
1049        {
1050            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
1051            w.write_all(&plaintext).unwrap();
1052            w.finish().unwrap();
1053        };
1054
1055        // Seek to the end of the plaintext before decrypting.
1056        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
1057        reader.seek(SeekFrom::End(0)).unwrap();
1058
1059        // Reading should return no bytes, because we're already at EOF.
1060        let mut buf = Vec::new();
1061        reader.read_to_end(&mut buf).unwrap();
1062        assert_eq!(buf.len(), 0);
1063    }
1064}