darkbio_crypto/stream/
mod.rs

1// The MIT License (MIT)
2//
3// Copyright (c) 2019 Jack Grigg
4//
5// Permission is hereby granted, free of charge, to any person obtaining a copy
6// of this software and associated documentation files (the "Software"), to deal
7// in the Software without restriction, including without limitation the rights
8// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9// copies of the Software, and to permit persons to whom the Software is
10// furnished to do so, subject to the following conditions:
11//
12// The above copyright notice and this permission notice shall be included in
13// all copies or substantial portions of the Software.
14//
15// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21// THE SOFTWARE.
22//
23// Source: https://github.com/str4d/rage/blob/main/age/src/primitives/stream.rs
24
25#![allow(clippy::all)]
26#![allow(unexpected_cfgs)]
27
28//! I/O helper structs for age file encryption and decryption.
29
30use age_core::secrecy::{ExposeSecret, SecretSlice};
31use chacha20poly1305::{
32    ChaCha20Poly1305,
33    aead::{Aead, KeyInit, KeySizeUser, generic_array::GenericArray},
34};
35use pin_project::pin_project;
36use std::cmp;
37use std::io::{self, Read, Seek, SeekFrom, Write};
38use zeroize::Zeroize;
39
40#[cfg(feature = "async")]
41use futures::{
42    io::{AsyncRead, AsyncWrite, Error},
43    ready,
44    task::{Context, Poll},
45};
46#[cfg(feature = "async")]
47use std::pin::Pin;
48
49const CHUNK_SIZE: usize = 64 * 1024;
50const TAG_SIZE: usize = 16;
51const ENCRYPTED_CHUNK_SIZE: usize = CHUNK_SIZE + TAG_SIZE;
52
53pub struct PayloadKey(pub GenericArray<u8, <ChaCha20Poly1305 as KeySizeUser>::KeySize>);
54
55impl Drop for PayloadKey {
56    fn drop(&mut self) {
57        self.0.as_mut_slice().zeroize();
58    }
59}
60
61/// The nonce used in age's STREAM encryption.
62///
63/// Structured as an 11 bytes of big endian counter, and 1 byte of last block flag
64/// (`0x00 / 0x01`). We store this in the lower 12 bytes of a `u128`.
65#[derive(Clone, Copy, Default)]
66struct Nonce(u128);
67
68impl Nonce {
69    /// Unsets last-chunk flag.
70    fn set_counter(&mut self, val: u64) {
71        self.0 = u128::from(val) << 8;
72    }
73
74    fn increment_counter(&mut self) {
75        // Increment the 11-byte counter
76        self.0 += 1 << 8;
77        if self.0 >> (8 * 12) != 0 {
78            panic!("We overflowed the nonce!");
79        }
80    }
81
82    fn is_last(&self) -> bool {
83        self.0 & 1 != 0
84    }
85
86    fn set_last(&mut self, last: bool) -> Result<(), ()> {
87        if !self.is_last() {
88            self.0 |= u128::from(last);
89            Ok(())
90        } else {
91            Err(())
92        }
93    }
94
95    fn to_bytes(self) -> [u8; 12] {
96        self.0.to_be_bytes()[4..]
97            .try_into()
98            .expect("slice is correct length")
99    }
100}
101
102#[cfg(feature = "async")]
103#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
104struct EncryptedChunk {
105    bytes: Vec<u8>,
106    offset: usize,
107}
108
109/// `STREAM[key](plaintext)`
110///
111/// The [STREAM] construction for online authenticated encryption, instantiated with
112/// ChaCha20-Poly1305 in 64KiB chunks, and a nonce structure of 11 bytes of big endian
113/// counter, and 1 byte of last block flag (0x00 / 0x01).
114///
115/// [STREAM]: https://eprint.iacr.org/2015/189.pdf
116pub struct Stream {
117    aead: ChaCha20Poly1305,
118    nonce: Nonce,
119}
120
121impl Stream {
122    fn new(key: PayloadKey) -> Self {
123        Stream {
124            aead: ChaCha20Poly1305::new(&key.0),
125            nonce: Nonce::default(),
126        }
127    }
128
129    /// Wraps `STREAM` encryption under the given `key` around a writer.
130    ///
131    /// `key` must **never** be repeated across multiple streams. In `age` this is
132    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
133    /// random nonce.
134    ///
135    /// [`HKDF`]: age_core::primitives::hkdf
136    pub fn encrypt<W: Write>(key: PayloadKey, inner: W) -> StreamWriter<W> {
137        StreamWriter {
138            stream: Self::new(key),
139            inner,
140            chunk: Vec::with_capacity(CHUNK_SIZE),
141            #[cfg(feature = "async")]
142            encrypted_chunk: None,
143        }
144    }
145
146    /// Wraps `STREAM` encryption under the given `key` around a writer.
147    ///
148    /// `key` must **never** be repeated across multiple streams. In `age` this is
149    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
150    /// random nonce.
151    ///
152    /// [`HKDF`]: age_core::primitives::hkdf
153    #[cfg(feature = "async")]
154    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
155    pub fn encrypt_async<W: AsyncWrite>(key: PayloadKey, inner: W) -> StreamWriter<W> {
156        StreamWriter {
157            stream: Self::new(key),
158            inner,
159            chunk: Vec::with_capacity(CHUNK_SIZE),
160            encrypted_chunk: None,
161        }
162    }
163
164    /// Wraps `STREAM` decryption under the given `key` around a reader.
165    ///
166    /// `key` must **never** be repeated across multiple streams. In `age` this is
167    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
168    /// random nonce.
169    ///
170    /// [`HKDF`]: age_core::primitives::hkdf
171    pub fn decrypt<R: Read>(key: PayloadKey, inner: R) -> StreamReader<R> {
172        StreamReader {
173            stream: Self::new(key),
174            inner,
175            encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
176            encrypted_pos: 0,
177            start: StartPos::Implicit(0),
178            plaintext_len: None,
179            cur_plaintext_pos: 0,
180            chunk: None,
181        }
182    }
183
184    /// Wraps `STREAM` decryption under the given `key` around a reader.
185    ///
186    /// `key` must **never** be repeated across multiple streams. In `age` this is
187    /// achieved by deriving the key with [`HKDF`] from both a random file key and a
188    /// random nonce.
189    ///
190    /// [`HKDF`]: age_core::primitives::hkdf
191    #[cfg(feature = "async")]
192    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
193    pub fn decrypt_async<R: AsyncRead>(key: PayloadKey, inner: R) -> StreamReader<R> {
194        StreamReader {
195            stream: Self::new(key),
196            inner,
197            encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
198            encrypted_pos: 0,
199            start: StartPos::Implicit(0),
200            plaintext_len: None,
201            cur_plaintext_pos: 0,
202            chunk: None,
203        }
204    }
205
206    fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<Vec<u8>> {
207        assert!(chunk.len() <= CHUNK_SIZE);
208
209        self.nonce.set_last(last).map_err(|_| {
210            io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
211        })?;
212
213        let encrypted = self
214            .aead
215            .encrypt(&self.nonce.to_bytes().into(), chunk)
216            .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size");
217        self.nonce.increment_counter();
218
219        Ok(encrypted)
220    }
221
222    fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<SecretSlice<u8>> {
223        assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE);
224
225        self.nonce.set_last(last).map_err(|_| {
226            io::Error::new(io::ErrorKind::InvalidData, "last chunk has been processed")
227        })?;
228
229        let decrypted = self
230            .aead
231            .decrypt(&self.nonce.to_bytes().into(), chunk)
232            .map(SecretSlice::from)
233            .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?;
234        self.nonce.increment_counter();
235
236        Ok(decrypted)
237    }
238
239    fn is_complete(&self) -> bool {
240        self.nonce.is_last()
241    }
242}
243
244/// Writes an encrypted age file.
245#[pin_project(project = StreamWriterProj)]
246pub struct StreamWriter<W> {
247    stream: Stream,
248    #[pin]
249    inner: W,
250    chunk: Vec<u8>,
251    #[cfg(feature = "async")]
252    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
253    encrypted_chunk: Option<EncryptedChunk>,
254}
255
256impl<W: Write> StreamWriter<W> {
257    /// Writes the final chunk of the age file.
258    ///
259    /// You **MUST** call `finish` when you are done writing, in order to finish the
260    /// encryption process. Failing to call `finish` will result in a truncated file that
261    /// that will fail to decrypt.
262    pub fn finish(mut self) -> io::Result<W> {
263        let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?;
264        self.inner.write_all(&encrypted)?;
265        Ok(self.inner)
266    }
267}
268
269impl<W: Write> Write for StreamWriter<W> {
270    fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
271        let mut bytes_written = 0;
272
273        while !buf.is_empty() {
274            let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());
275            self.chunk.extend_from_slice(&buf[..to_write]);
276            bytes_written += to_write;
277            buf = &buf[to_write..];
278
279            // At this point, either buf is empty, or we have a full chunk.
280            assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);
281
282            // Only encrypt the chunk if we have more data to write, as the last
283            // chunk must be written in finish().
284            if !buf.is_empty() {
285                let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?;
286                self.inner.write_all(&encrypted)?;
287                self.chunk.clear();
288            }
289        }
290
291        Ok(bytes_written)
292    }
293
294    fn flush(&mut self) -> io::Result<()> {
295        self.inner.flush()
296    }
297}
298
299#[cfg(feature = "async")]
300#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
301impl<W: AsyncWrite> StreamWriter<W> {
302    fn poll_flush_chunk(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
303        let StreamWriterProj {
304            mut inner,
305            encrypted_chunk,
306            ..
307        } = self.project();
308
309        if let Some(chunk) = encrypted_chunk {
310            loop {
311                chunk.offset +=
312                    ready!(inner.as_mut().poll_write(cx, &chunk.bytes[chunk.offset..]))?;
313                if chunk.offset == chunk.bytes.len() {
314                    break;
315                }
316            }
317        }
318        *encrypted_chunk = None;
319
320        Poll::Ready(Ok(()))
321    }
322}
323
324#[cfg(feature = "async")]
325#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
326impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
327    fn poll_write(
328        mut self: Pin<&mut Self>,
329        cx: &mut Context,
330        mut buf: &[u8],
331    ) -> Poll<io::Result<usize>> {
332        // If the buffer is empty, return immediately
333        if buf.is_empty() {
334            return Poll::Ready(Ok(0));
335        }
336
337        loop {
338            ready!(self.as_mut().poll_flush_chunk(cx))?;
339
340            // We can encounter one of three cases here:
341            // 1. `0 < buf.len() <= CHUNK_SIZE - self.chunk.len()`: we append to the
342            //    partial chunk and return. This may happen to complete the chunk.
343            // 2. `0 < CHUNK_SIZE - self.chunk.len() < buf.len()`: we consume part of
344            //    `buf` to complete the chunk, encrypt it, and return.
345            // 3. `0 == CHUNK_SIZE - self.chunk.len() < buf.len()`: we hit case 1 in a
346            //    previous invocation. We encrypt the chunk, and then loop around (where
347            //    we are guaranteed to hit case 1).
348            let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());
349
350            self.as_mut()
351                .project()
352                .chunk
353                .extend_from_slice(&buf[..to_write]);
354            buf = &buf[to_write..];
355
356            // At this point, either buf is empty, or we have a full chunk.
357            assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);
358
359            // Only encrypt the chunk if we have more data to write, as the last
360            // chunk must be written in poll_close().
361            if !buf.is_empty() {
362                let this = self.as_mut().project();
363                *this.encrypted_chunk = Some(EncryptedChunk {
364                    bytes: this.stream.encrypt_chunk(this.chunk, false)?,
365                    offset: 0,
366                });
367                this.chunk.clear();
368            }
369
370            // If we wrote some data, return how much we wrote
371            if to_write > 0 {
372                return Poll::Ready(Ok(to_write));
373            }
374
375            // If we didn't write any data, loop and write some, to ensure
376            // this function does not return 0. This enables compatibility with
377            // futures::io::copy() and tokio::io::copy(), which will return a
378            // WriteZero error in that case.
379            // Since those functions copy 8K at a time, and CHUNK_SIZE is
380            // a multiple of 8K, this ends up happening once for each chunk
381            // after the first one
382        }
383    }
384
385    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
386        ready!(self.as_mut().poll_flush_chunk(cx))?;
387        self.project().inner.poll_flush(cx)
388    }
389
390    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
391        // Flush any remaining encrypted chunk bytes.
392        ready!(self.as_mut().poll_flush_chunk(cx))?;
393
394        if !self.stream.is_complete() {
395            // Finish the stream.
396            let this = self.as_mut().project();
397            *this.encrypted_chunk = Some(EncryptedChunk {
398                bytes: this.stream.encrypt_chunk(this.chunk, true)?,
399                offset: 0,
400            });
401        }
402
403        // Flush the final chunk (if we didn't in the first call).
404        ready!(self.as_mut().poll_flush_chunk(cx))?;
405        self.project().inner.poll_close(cx)
406    }
407}
408
409/// The position in the underlying reader corresponding to the start of the stream.
410///
411/// To impl Seek for StreamReader, we need to know the point in the reader corresponding
412/// to the first byte of the stream. But we can't query the reader for its current
413/// position without having a specific constructor for `R: Read + Seek`, which makes the
414/// higher-level API more complex. Instead, we count the number of bytes that have been
415/// read from the reader until we first need to seek, and then inside `impl Seek` we can
416/// query the reader's current position and figure out where the start was.
417enum StartPos {
418    /// An offset that we can subtract from the current position.
419    Implicit(u64),
420    /// The precise start position.
421    Explicit(u64),
422}
423
424/// Provides access to a decrypted age file.
425#[pin_project]
426pub struct StreamReader<R> {
427    stream: Stream,
428    #[pin]
429    inner: R,
430    encrypted_chunk: Vec<u8>,
431    encrypted_pos: usize,
432    start: StartPos,
433    plaintext_len: Option<u64>,
434    cur_plaintext_pos: u64,
435    chunk: Option<SecretSlice<u8>>,
436}
437
438impl<R> StreamReader<R> {
439    fn count_bytes(&mut self, read: usize) {
440        // We only need to count if we haven't yet worked out the start position.
441        if let StartPos::Implicit(offset) = &mut self.start {
442            *offset += read as u64;
443        }
444    }
445
446    fn decrypt_chunk(&mut self) -> io::Result<()> {
447        self.count_bytes(self.encrypted_pos);
448        let chunk = &self.encrypted_chunk[..self.encrypted_pos];
449
450        if chunk.is_empty() {
451            if !self.stream.is_complete() {
452                // Stream has ended before seeing the last chunk.
453                return Err(io::Error::new(
454                    io::ErrorKind::UnexpectedEof,
455                    "age file is truncated",
456                ));
457            }
458        } else {
459            // This check works for all cases except when the age file is an integer
460            // multiple of the chunk size. In that case, we try decrypting twice on a
461            // decryption failure.
462            let last = chunk.len() < ENCRYPTED_CHUNK_SIZE;
463
464            self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) {
465                (Ok(chunk), _)
466                    if chunk.expose_secret().is_empty() && self.cur_plaintext_pos > 0 =>
467                {
468                    assert!(last);
469                    return Err(io::Error::new(
470                        io::ErrorKind::InvalidData,
471                        "last chunk is empty",
472                    ));
473                }
474                (Ok(chunk), _) => Some(chunk),
475                (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?),
476                (Err(e), true) => return Err(e),
477            };
478        }
479
480        // We've finished with this encrypted chunk.
481        self.encrypted_pos = 0;
482
483        Ok(())
484    }
485
486    fn read_from_chunk(&mut self, buf: &mut [u8]) -> usize {
487        if self.chunk.is_none() {
488            return 0;
489        }
490
491        let chunk = self.chunk.as_ref().unwrap();
492        let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE;
493
494        let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len());
495
496        buf[..to_read]
497            .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]);
498        self.cur_plaintext_pos += to_read as u64;
499        if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 {
500            // We've finished with the current chunk.
501            self.chunk = None;
502        }
503
504        to_read
505    }
506}
507
508impl<R: Read> Read for StreamReader<R> {
509    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
510        if self.chunk.is_none() {
511            while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
512                match self
513                    .inner
514                    .read(&mut self.encrypted_chunk[self.encrypted_pos..])
515                {
516                    Ok(0) => break,
517                    Ok(n) => self.encrypted_pos += n,
518                    Err(e) => match e.kind() {
519                        io::ErrorKind::Interrupted => (),
520                        _ => return Err(e),
521                    },
522                }
523            }
524            self.decrypt_chunk()?;
525        }
526
527        Ok(self.read_from_chunk(buf))
528    }
529}
530
531#[cfg(feature = "async")]
532#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
533impl<R: AsyncRead + Unpin> AsyncRead for StreamReader<R> {
534    fn poll_read(
535        mut self: Pin<&mut Self>,
536        cx: &mut Context,
537        buf: &mut [u8],
538    ) -> Poll<Result<usize, Error>> {
539        if self.chunk.is_none() {
540            while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
541                let this = self.as_mut().project();
542                match ready!(
543                    this.inner
544                        .poll_read(cx, &mut this.encrypted_chunk[*this.encrypted_pos..])
545                ) {
546                    Ok(0) => break,
547                    Ok(n) => self.encrypted_pos += n,
548                    Err(e) => match e.kind() {
549                        io::ErrorKind::Interrupted => (),
550                        _ => return Poll::Ready(Err(e)),
551                    },
552                }
553            }
554            self.decrypt_chunk()?;
555        }
556
557        Poll::Ready(Ok(self.read_from_chunk(buf)))
558    }
559}
560
561impl<R: Read + Seek> StreamReader<R> {
562    fn start(&mut self) -> io::Result<u64> {
563        match self.start {
564            StartPos::Implicit(offset) => {
565                let current = self.inner.stream_position()?;
566                let start = current - offset;
567
568                // Cache the start for future calls.
569                self.start = StartPos::Explicit(start);
570
571                Ok(start)
572            }
573            StartPos::Explicit(start) => Ok(start),
574        }
575    }
576
577    /// Returns the length of the plaintext
578    fn len(&mut self) -> io::Result<u64> {
579        match self.plaintext_len {
580            None => {
581                // Cache the current position and nonce, and then grab the start and end
582                // ciphertext positions.
583                let cur_pos = self.inner.stream_position()?;
584                let cur_nonce = self.stream.nonce.0;
585                let ct_start = self.start()?;
586                let ct_end = self.inner.seek(SeekFrom::End(0))?;
587                let ct_len = ct_end - ct_start;
588
589                // Use ceiling division to determine the number of chunks.
590                let num_chunks =
591                    (ct_len + (ENCRYPTED_CHUNK_SIZE as u64 - 1)) / ENCRYPTED_CHUNK_SIZE as u64;
592
593                // Authenticate the ciphertext length by checking that we can successfully
594                // decrypt the last chunk _as_ a last chunk.
595                let last_chunk_start = ct_start + ((num_chunks - 1) * ENCRYPTED_CHUNK_SIZE as u64);
596                let mut last_chunk = Vec::with_capacity((ct_end - last_chunk_start) as usize);
597                self.inner.seek(SeekFrom::Start(last_chunk_start))?;
598                self.inner.read_to_end(&mut last_chunk)?;
599                self.stream.nonce.set_counter(num_chunks - 1);
600                self.stream.decrypt_chunk(&last_chunk, true).map_err(|_| {
601                    io::Error::new(
602                        io::ErrorKind::InvalidData,
603                        "Last chunk is invalid, stream might be truncated",
604                    )
605                })?;
606
607                // Now that we have authenticated the ciphertext length, we can use it to
608                // calculate the plaintext length.
609                let total_tag_size = num_chunks * TAG_SIZE as u64;
610                let pt_len = ct_len - total_tag_size;
611
612                // Return to the original position and restore the nonce.
613                self.inner.seek(SeekFrom::Start(cur_pos))?;
614                self.stream.nonce = Nonce(cur_nonce);
615
616                // Cache the length for future calls.
617                self.plaintext_len = Some(pt_len);
618
619                Ok(pt_len)
620            }
621            Some(pt_len) => Ok(pt_len),
622        }
623    }
624}
625
626impl<R: Read + Seek> Seek for StreamReader<R> {
627    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
628        // Convert the offset into the target position within the plaintext
629        let start = self.start()?;
630        let target_pos = match pos {
631            SeekFrom::Start(offset) => offset,
632            SeekFrom::Current(offset) => {
633                let res = (self.cur_plaintext_pos as i64) + offset;
634                if res >= 0 {
635                    res as u64
636                } else {
637                    return Err(io::Error::new(
638                        io::ErrorKind::InvalidData,
639                        "cannot seek before the start",
640                    ));
641                }
642            }
643            SeekFrom::End(offset) => {
644                let res = (self.len()? as i64) + offset;
645                if res >= 0 {
646                    res as u64
647                } else {
648                    return Err(io::Error::new(
649                        io::ErrorKind::InvalidData,
650                        "cannot seek before the start",
651                    ));
652                }
653            }
654        };
655
656        let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64;
657
658        let target_chunk_index = target_pos / CHUNK_SIZE as u64;
659        let target_chunk_offset = target_pos % CHUNK_SIZE as u64;
660
661        if target_chunk_index == cur_chunk_index {
662            // We just need to reposition ourselves within the current chunk.
663            self.cur_plaintext_pos = target_pos;
664        } else {
665            // Clear the current chunk
666            self.chunk = None;
667
668            // Seek to the beginning of the target chunk
669            self.inner.seek(SeekFrom::Start(
670                start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64),
671            ))?;
672            self.stream.nonce.set_counter(target_chunk_index);
673            self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64;
674
675            // Read and drop bytes from the chunk to reach the target position.
676            if target_chunk_offset > 0 {
677                let mut to_drop = vec![0; target_chunk_offset as usize];
678                self.read_exact(&mut to_drop)?;
679            }
680            // We need to handle the edge case where the last chunk is not short, and
681            // `target_pos == self.len()` (so `target_chunk_index` points to the chunk
682            // after the last chunk). The next read would return no bytes, but because
683            // `target_chunk_offset == 0` we weren't forced to read past any in-chunk
684            // bytes, and thus have not set the `last` flag on the nonce.
685            //
686            // To handle this edge case, when `target_pos` is a multiple of the chunk
687            // size (i.e. this conditional branch), we compute the length of the
688            // plaintext. This is cached, so the overhead should be minimal.
689            else if target_pos == self.len()? {
690                self.stream
691                    .nonce
692                    .set_last(true)
693                    .expect("We unset the last chunk flag earlier");
694            }
695        }
696
697        // All done!
698        Ok(target_pos)
699    }
700}
701
702#[cfg(test)]
703mod tests {
704    use age_core::secrecy::ExposeSecret;
705    use std::io::{self, Cursor, Read, Seek, SeekFrom, Write};
706
707    use super::{CHUNK_SIZE, PayloadKey, Stream};
708
709    #[cfg(feature = "async")]
710    use futures::{
711        io::{AsyncRead, AsyncWrite},
712        pin_mut,
713        task::Poll,
714    };
715    #[cfg(feature = "async")]
716    use futures_test::task::noop_context;
717
718    #[test]
719    fn chunk_round_trip() {
720        let data = vec![42; CHUNK_SIZE];
721
722        let encrypted = {
723            let mut s = Stream::new(PayloadKey([7; 32].into()));
724            s.encrypt_chunk(&data, false).unwrap()
725        };
726
727        let decrypted = {
728            let mut s = Stream::new(PayloadKey([7; 32].into()));
729            s.decrypt_chunk(&encrypted, false).unwrap()
730        };
731
732        assert_eq!(decrypted.expose_secret(), &data);
733    }
734
735    #[test]
736    fn last_chunk_round_trip() {
737        let data = vec![42; CHUNK_SIZE];
738
739        let encrypted = {
740            let mut s = Stream::new(PayloadKey([7; 32].into()));
741            let res = s.encrypt_chunk(&data, true).unwrap();
742
743            // Further calls return an error
744            assert_eq!(
745                s.encrypt_chunk(&data, false).unwrap_err().kind(),
746                io::ErrorKind::WriteZero
747            );
748            assert_eq!(
749                s.encrypt_chunk(&data, true).unwrap_err().kind(),
750                io::ErrorKind::WriteZero
751            );
752
753            res
754        };
755
756        let decrypted = {
757            let mut s = Stream::new(PayloadKey([7; 32].into()));
758            let res = s.decrypt_chunk(&encrypted, true).unwrap();
759
760            // Further calls return an error
761            match s.decrypt_chunk(&encrypted, false) {
762                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
763                _ => panic!("Expected error"),
764            }
765            match s.decrypt_chunk(&encrypted, true) {
766                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
767                _ => panic!("Expected error"),
768            }
769
770            res
771        };
772
773        assert_eq!(decrypted.expose_secret(), &data);
774    }
775
776    fn stream_round_trip(data: &[u8]) {
777        let mut encrypted = vec![];
778        {
779            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
780            w.write_all(data).unwrap();
781            w.finish().unwrap();
782        };
783
784        let decrypted = {
785            let mut buf = vec![];
786            let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
787            r.read_to_end(&mut buf).unwrap();
788            buf
789        };
790
791        assert_eq!(decrypted, data);
792    }
793
794    /// Check that we can encrypt an empty slice.
795    ///
796    /// This is the sole exception to the "last chunk must be non-empty" rule.
797    #[test]
798    fn stream_round_trip_empty() {
799        stream_round_trip(&[]);
800    }
801
802    #[test]
803    fn stream_round_trip_short() {
804        stream_round_trip(&[42; 1024]);
805    }
806
807    #[test]
808    fn stream_round_trip_chunk() {
809        stream_round_trip(&[42; CHUNK_SIZE]);
810    }
811
812    #[test]
813    fn stream_round_trip_long() {
814        stream_round_trip(&[42; 100 * 1024]);
815    }
816
817    #[cfg(feature = "async")]
818    fn stream_async_round_trip(data: &[u8]) {
819        let mut encrypted = vec![];
820        {
821            let w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
822            pin_mut!(w);
823
824            let mut cx = noop_context();
825
826            let mut tmp = data;
827            loop {
828                match w.as_mut().poll_write(&mut cx, tmp) {
829                    Poll::Ready(Ok(0)) => break,
830                    Poll::Ready(Ok(written)) => tmp = &tmp[written..],
831                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
832                    Poll::Pending => panic!("Unexpected Pending"),
833                }
834            }
835            loop {
836                match w.as_mut().poll_close(&mut cx) {
837                    Poll::Ready(Ok(())) => break,
838                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
839                    Poll::Pending => panic!("Unexpected Pending"),
840                }
841            }
842        };
843
844        let decrypted = {
845            let mut buf = vec![];
846            let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
847            pin_mut!(r);
848
849            let mut cx = noop_context();
850
851            let mut tmp = [0; 4096];
852            loop {
853                match r.as_mut().poll_read(&mut cx, &mut tmp) {
854                    Poll::Ready(Ok(0)) => break buf,
855                    Poll::Ready(Ok(read)) => buf.extend_from_slice(&tmp[..read]),
856                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
857                    Poll::Pending => panic!("Unexpected Pending"),
858                }
859            }
860        };
861
862        assert_eq!(decrypted, data);
863    }
864
865    #[cfg(feature = "async")]
866    #[test]
867    fn stream_async_round_trip_short() {
868        stream_async_round_trip(&[42; 1024]);
869    }
870
871    #[cfg(feature = "async")]
872    #[test]
873    fn stream_async_round_trip_chunk() {
874        stream_async_round_trip(&[42; CHUNK_SIZE]);
875    }
876
877    #[cfg(feature = "async")]
878    #[test]
879    fn stream_async_round_trip_long() {
880        stream_async_round_trip(&[42; 100 * 1024]);
881    }
882
883    #[cfg(feature = "async")]
884    fn stream_async_io_copy(data: &[u8]) {
885        use futures::AsyncWriteExt;
886
887        let runtime = tokio::runtime::Builder::new_current_thread()
888            .build()
889            .unwrap();
890        let mut encrypted = vec![];
891        let result = runtime.block_on(async {
892            let mut w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
893            match futures::io::copy(data, &mut w).await {
894                Ok(written) => {
895                    w.close().await.unwrap();
896                    Ok(written)
897                }
898                Err(e) => Err(e),
899            }
900        });
901
902        match result {
903            Ok(written) => assert_eq!(written, data.len() as u64),
904            Err(e) => panic!("Unexpected error: {}", e),
905        }
906
907        let decrypted = {
908            let mut buf = vec![];
909            let result = runtime.block_on(async {
910                let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
911                futures::io::copy(r, &mut buf).await
912            });
913
914            match result {
915                Ok(written) => assert_eq!(written, data.len() as u64),
916                Err(e) => panic!("Unexpected error: {}", e),
917            }
918
919            buf
920        };
921
922        assert_eq!(decrypted, data);
923    }
924
925    #[cfg(feature = "async")]
926    #[test]
927    fn stream_async_io_copy_short() {
928        stream_async_io_copy(&[42; 1024]);
929    }
930
931    #[cfg(feature = "async")]
932    #[test]
933    fn stream_async_io_copy_chunk() {
934        stream_async_io_copy(&[42; CHUNK_SIZE]);
935    }
936
937    #[cfg(feature = "async")]
938    #[test]
939    fn stream_async_io_copy_long() {
940        stream_async_io_copy(&[42; 100 * 1024]);
941    }
942
943    #[test]
944    fn stream_fails_to_decrypt_truncated_file() {
945        let data = vec![42; 2 * CHUNK_SIZE];
946
947        let mut encrypted = vec![];
948        {
949            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
950            w.write_all(&data).unwrap();
951            // Forget to call w.finish()!
952        };
953
954        let mut buf = vec![];
955        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
956        assert_eq!(
957            r.read_to_end(&mut buf).unwrap_err().kind(),
958            io::ErrorKind::UnexpectedEof
959        );
960    }
961
962    #[test]
963    fn stream_seeking() {
964        let mut data = vec![0; 100 * 1024];
965        for (i, b) in data.iter_mut().enumerate() {
966            *b = i as u8;
967        }
968
969        let mut encrypted = vec![];
970        {
971            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
972            w.write_all(&data).unwrap();
973            w.finish().unwrap();
974        };
975
976        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(encrypted));
977
978        // Read through into the second chunk
979        let mut buf = vec![0; 100];
980        for i in 0..700 {
981            r.read_exact(&mut buf).unwrap();
982            assert_eq!(&buf[..], &data[100 * i..100 * (i + 1)]);
983        }
984
985        // Seek back into the first chunk
986        r.seek(SeekFrom::Start(250)).unwrap();
987        r.read_exact(&mut buf).unwrap();
988        assert_eq!(&buf[..], &data[250..350]);
989
990        // Seek forwards within this chunk
991        r.seek(SeekFrom::Current(510)).unwrap();
992        r.read_exact(&mut buf).unwrap();
993        assert_eq!(&buf[..], &data[860..960]);
994
995        // Seek backwards from the end
996        r.seek(SeekFrom::End(-1337)).unwrap();
997        r.read_exact(&mut buf).unwrap();
998        assert_eq!(&buf[..], &data[data.len() - 1337..data.len() - 1237]);
999    }
1000
1001    #[test]
1002    fn seek_from_end_fails_on_truncation() {
1003        // The plaintext is the string "hello" followed by 65536 zeros, just enough to
1004        // give us some bytes to play with in the second chunk.
1005        let mut plaintext: Vec<u8> = b"hello".to_vec();
1006        plaintext.extend_from_slice(&[0; 65536]);
1007
1008        // Encrypt the plaintext just like the example code in the docs.
1009        let mut encrypted = vec![];
1010        {
1011            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
1012            w.write_all(&plaintext).unwrap();
1013            w.finish().unwrap();
1014        };
1015
1016        // First check the correct behavior of seeks relative to EOF. Create a decrypting
1017        // reader, and move it one byte forward from the start, using SeekFrom::End.
1018        // Confirm that reading 4 bytes from that point gives us "ello", as it should.
1019        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
1020        let eof_relative_offset = 1_i64 - plaintext.len() as i64;
1021        reader.seek(SeekFrom::End(eof_relative_offset)).unwrap();
1022        let mut buf = [0; 4];
1023        reader.read_exact(&mut buf).unwrap();
1024        assert_eq!(&buf, b"ello", "This is correct.");
1025
1026        // Do the same thing again, except this time truncate the ciphertext by one byte
1027        // first. This should cause some sort of error, instead of a successful read that
1028        // returns the wrong plaintext.
1029        let truncated_ciphertext = &encrypted[..encrypted.len() - 1];
1030        let mut truncated_reader = Stream::decrypt(
1031            PayloadKey([7; 32].into()),
1032            Cursor::new(truncated_ciphertext),
1033        );
1034        // Use the same seek target as above.
1035        match truncated_reader.seek(SeekFrom::End(eof_relative_offset)) {
1036            Err(e) => {
1037                assert_eq!(e.kind(), io::ErrorKind::InvalidData);
1038                assert_eq!(
1039                    &e.to_string(),
1040                    "Last chunk is invalid, stream might be truncated",
1041                );
1042            }
1043            Ok(_) => panic!("This is a security issue."),
1044        }
1045    }
1046
1047    #[test]
1048    fn seek_from_end_with_exact_chunk() {
1049        let plaintext: Vec<u8> = vec![42; 65536];
1050
1051        // Encrypt the plaintext just like the example code in the docs.
1052        let mut encrypted = vec![];
1053        {
1054            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
1055            w.write_all(&plaintext).unwrap();
1056            w.finish().unwrap();
1057        };
1058
1059        // Seek to the end of the plaintext before decrypting.
1060        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
1061        reader.seek(SeekFrom::End(0)).unwrap();
1062
1063        // Reading should return no bytes, because we're already at EOF.
1064        let mut buf = Vec::new();
1065        reader.read_to_end(&mut buf).unwrap();
1066        assert_eq!(buf.len(), 0);
1067    }
1068}