1#![allow(clippy::all)]
26#![allow(unexpected_cfgs)]
27
28use age_core::secrecy::{ExposeSecret, SecretSlice};
31use chacha20poly1305::{
32 ChaCha20Poly1305,
33 aead::{Aead, KeyInit, KeySizeUser, generic_array::GenericArray},
34};
35use pin_project::pin_project;
36use std::cmp;
37use std::io::{self, Read, Seek, SeekFrom, Write};
38use zeroize::Zeroize;
39
40#[cfg(feature = "async")]
41use futures::{
42 io::{AsyncRead, AsyncWrite, Error},
43 ready,
44 task::{Context, Poll},
45};
46#[cfg(feature = "async")]
47use std::pin::Pin;
48
/// Size in bytes of one plaintext chunk (64 KiB).
const CHUNK_SIZE: usize = 1 << 16;
/// Size in bytes of the authentication tag appended to every encrypted chunk.
const TAG_SIZE: usize = 16;
/// Size in bytes of one full encrypted chunk: a full plaintext chunk plus its tag.
const ENCRYPTED_CHUNK_SIZE: usize = TAG_SIZE + CHUNK_SIZE;
52
53pub struct PayloadKey(pub GenericArray<u8, <ChaCha20Poly1305 as KeySizeUser>::KeySize>);
54
55impl Drop for PayloadKey {
56 fn drop(&mut self) {
57 self.0.as_mut_slice().zeroize();
58 }
59}
60
/// The nonce for one STREAM chunk, stored in the low 96 bits of a `u128`.
///
/// Layout, most significant byte first: an 11-byte big-endian chunk counter
/// (bits 8..96) followed by a one-byte last-chunk flag (bits 0..8, of which only
/// bit 0 is ever set).
#[derive(Clone, Copy, Default)]
struct Nonce(u128);

impl Nonce {
    /// Sets the chunk counter to `val`.
    ///
    /// The whole value is overwritten, so this also clears the last-chunk flag.
    fn set_counter(&mut self, val: u64) {
        let counter = u128::from(val);
        self.0 = counter << 8;
    }

    /// Advances the chunk counter by one.
    ///
    /// # Panics
    ///
    /// Panics if the 11-byte counter overflows into the unused high bytes.
    fn increment_counter(&mut self) {
        self.0 += 1u128 << 8;
        // Anything above the low 12 bytes means the counter left its field.
        assert!(self.0 >> 96 == 0, "We overflowed the nonce!");
    }

    /// Returns `true` if the last-chunk flag is set.
    fn is_last(&self) -> bool {
        self.0 & 1 == 1
    }

    /// Sets the last-chunk flag if `last` is `true`; `false` leaves it unset.
    ///
    /// Returns `Err(())` without changing anything if the flag was already set,
    /// i.e. the last chunk has already been processed.
    fn set_last(&mut self, last: bool) -> Result<(), ()> {
        if self.is_last() {
            return Err(());
        }
        if last {
            self.0 |= 1;
        }
        Ok(())
    }

    /// Returns the 12-byte big-endian nonce passed to the AEAD.
    fn to_bytes(self) -> [u8; 12] {
        let wide = self.0.to_be_bytes();
        let mut nonce = [0u8; 12];
        // The top 4 bytes of the u128 are always zero and are not part of the nonce.
        nonce.copy_from_slice(&wide[4..]);
        nonce
    }
}
101
/// An encrypted chunk waiting to be written to the inner async writer.
///
/// Used by the `AsyncWrite` implementation of `StreamWriter` to resume a
/// partially completed write after the inner writer returns `Poll::Pending`.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
struct EncryptedChunk {
    /// The complete ciphertext of the chunk.
    bytes: Vec<u8>,
    /// Number of bytes of `bytes` already accepted by the inner writer.
    offset: usize,
}
108
/// `STREAM[key](plaintext)`
///
/// The [STREAM] construction for online authenticated encryption, instantiated
/// with ChaCha20-Poly1305 over 64 KiB plaintext chunks. Each chunk's nonce is an
/// 11-byte big-endian chunk counter followed by a 1-byte last-chunk flag
/// (0x00 / 0x01); see `Nonce`.
///
/// [STREAM]: https://eprint.iacr.org/2015/189.pdf
pub struct Stream {
    // The AEAD keyed with the payload key.
    aead: ChaCha20Poly1305,
    // Nonce for the next chunk to be encrypted or decrypted.
    nonce: Nonce,
}
120
121impl Stream {
122 fn new(key: PayloadKey) -> Self {
123 Stream {
124 aead: ChaCha20Poly1305::new(&key.0),
125 nonce: Nonce::default(),
126 }
127 }
128
129 pub fn encrypt<W: Write>(key: PayloadKey, inner: W) -> StreamWriter<W> {
137 StreamWriter {
138 stream: Self::new(key),
139 inner,
140 chunk: Vec::with_capacity(CHUNK_SIZE),
141 #[cfg(feature = "async")]
142 encrypted_chunk: None,
143 }
144 }
145
146 #[cfg(feature = "async")]
154 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
155 pub fn encrypt_async<W: AsyncWrite>(key: PayloadKey, inner: W) -> StreamWriter<W> {
156 StreamWriter {
157 stream: Self::new(key),
158 inner,
159 chunk: Vec::with_capacity(CHUNK_SIZE),
160 encrypted_chunk: None,
161 }
162 }
163
164 pub fn decrypt<R: Read>(key: PayloadKey, inner: R) -> StreamReader<R> {
172 StreamReader {
173 stream: Self::new(key),
174 inner,
175 encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
176 encrypted_pos: 0,
177 start: StartPos::Implicit(0),
178 plaintext_len: None,
179 cur_plaintext_pos: 0,
180 chunk: None,
181 }
182 }
183
184 #[cfg(feature = "async")]
192 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
193 pub fn decrypt_async<R: AsyncRead>(key: PayloadKey, inner: R) -> StreamReader<R> {
194 StreamReader {
195 stream: Self::new(key),
196 inner,
197 encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
198 encrypted_pos: 0,
199 start: StartPos::Implicit(0),
200 plaintext_len: None,
201 cur_plaintext_pos: 0,
202 chunk: None,
203 }
204 }
205
206 fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<Vec<u8>> {
207 assert!(chunk.len() <= CHUNK_SIZE);
208
209 self.nonce.set_last(last).map_err(|_| {
210 io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
211 })?;
212
213 let encrypted = self
214 .aead
215 .encrypt(&self.nonce.to_bytes().into(), chunk)
216 .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size");
217 self.nonce.increment_counter();
218
219 Ok(encrypted)
220 }
221
222 fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<SecretSlice<u8>> {
223 assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE);
224
225 self.nonce.set_last(last).map_err(|_| {
226 io::Error::new(io::ErrorKind::InvalidData, "last chunk has been processed")
227 })?;
228
229 let decrypted = self
230 .aead
231 .decrypt(&self.nonce.to_bytes().into(), chunk)
232 .map(SecretSlice::from)
233 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?;
234 self.nonce.increment_counter();
235
236 Ok(decrypted)
237 }
238
239 fn is_complete(&self) -> bool {
240 self.nonce.is_last()
241 }
242}
243
/// Encrypts plaintext written to it and writes the STREAM ciphertext to `inner`.
///
/// Created by `Stream::encrypt` or `Stream::encrypt_async`.
#[pin_project(project = StreamWriterProj)]
pub struct StreamWriter<W> {
    // STREAM cipher and nonce state.
    stream: Stream,
    // The writer receiving ciphertext (pinned for the async implementation).
    #[pin]
    inner: W,
    // Buffered plaintext for the current chunk; at most CHUNK_SIZE bytes.
    chunk: Vec<u8>,
    // Async only: ciphertext of a chunk not yet fully written to `inner`.
    #[cfg(feature = "async")]
    #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
    encrypted_chunk: Option<EncryptedChunk>,
}
255
256impl<W: Write> StreamWriter<W> {
257 pub fn finish(mut self) -> io::Result<W> {
263 let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?;
264 self.inner.write_all(&encrypted)?;
265 Ok(self.inner)
266 }
267}
268
269impl<W: Write> Write for StreamWriter<W> {
270 fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
271 let mut bytes_written = 0;
272
273 while !buf.is_empty() {
274 let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());
275 self.chunk.extend_from_slice(&buf[..to_write]);
276 bytes_written += to_write;
277 buf = &buf[to_write..];
278
279 assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);
281
282 if !buf.is_empty() {
285 let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?;
286 self.inner.write_all(&encrypted)?;
287 self.chunk.clear();
288 }
289 }
290
291 Ok(bytes_written)
292 }
293
294 fn flush(&mut self) -> io::Result<()> {
295 self.inner.flush()
296 }
297}
298
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> StreamWriter<W> {
    /// Writes any pending encrypted chunk to the inner writer.
    ///
    /// Returns `Poll::Ready(Ok(()))` once no encrypted chunk is pending. If the
    /// inner writer is not ready, returns `Poll::Pending`; the progress made so far
    /// is kept in `EncryptedChunk::offset`, so the next call resumes where this
    /// one stopped.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner writer. Returns `WriteZero` if the inner
    /// writer accepts zero bytes of a non-empty buffer, which previously made this
    /// loop spin forever.
    fn poll_flush_chunk(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let StreamWriterProj {
            mut inner,
            encrypted_chunk,
            ..
        } = self.project();

        if let Some(chunk) = encrypted_chunk {
            while chunk.offset < chunk.bytes.len() {
                let written =
                    ready!(inner.as_mut().poll_write(cx, &chunk.bytes[chunk.offset..]))?;
                if written == 0 {
                    // The chunk stays pending so a later call could retry.
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write whole buffer",
                    )));
                }
                chunk.offset += written;
            }
        }
        *encrypted_chunk = None;

        Poll::Ready(Ok(()))
    }
}
323
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
    /// Buffers plaintext from `buf`, encrypting a chunk once it is full and more
    /// data follows.
    ///
    /// Returns how many bytes of `buf` were consumed, and returns `Ok(0)` only
    /// for an empty `buf`. Consumers such as `futures::io::copy` treat `Ok(0)`
    /// as a write-zero error.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        mut buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // If the buffer is empty, return immediately.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        loop {
            // Finish writing any previously encrypted chunk before buffering more.
            ready!(self.as_mut().poll_flush_chunk(cx))?;

            // One of three cases holds here:
            // 1. `0 < buf.len() <= CHUNK_SIZE - self.chunk.len()`: append all of
            //    `buf` to the partial chunk and return (possibly filling it).
            // 2. `0 < CHUNK_SIZE - self.chunk.len() < buf.len()`: consume part of
            //    `buf` to complete the chunk, encrypt it, and return.
            // 3. `0 == CHUNK_SIZE - self.chunk.len() < buf.len()`: a previous call
            //    hit case 1 and filled the chunk. Encrypt it, then loop around,
            //    where case 1 or 2 is guaranteed.
            let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());

            self.as_mut()
                .project()
                .chunk
                .extend_from_slice(&buf[..to_write]);
            buf = &buf[to_write..];

            // At this point, either buf is empty, or we have a full chunk.
            assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);

            // Only encrypt the chunk if more data follows. The final chunk must be
            // encrypted with the last-chunk flag, which happens in poll_close().
            if !buf.is_empty() {
                let this = self.as_mut().project();
                *this.encrypted_chunk = Some(EncryptedChunk {
                    bytes: this.stream.encrypt_chunk(this.chunk, false)?,
                    offset: 0,
                });
                this.chunk.clear();
            }

            // If we consumed some data, report how much.
            if to_write > 0 {
                return Poll::Ready(Ok(to_write));
            }

            // Case 3 consumed nothing. Loop instead of returning Ok(0), which
            // would signal a write-zero condition to the caller.
        }
    }

    /// Writes out any pending encrypted chunk, then flushes the inner writer.
    ///
    /// A partially filled plaintext chunk is neither encrypted nor written.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_flush(cx)
    }

    /// Encrypts the buffered plaintext as the final chunk, writes it, and closes
    /// the inner writer.
    ///
    /// It may be polled again after `Poll::Pending`. The final chunk is only
    /// produced while the stream is not yet complete, so it is never encrypted
    /// twice.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        // Flush any chunk left over from poll_write before producing the last one.
        ready!(self.as_mut().poll_flush_chunk(cx))?;

        if !self.stream.is_complete() {
            let this = self.as_mut().project();
            // Encrypt the remaining plaintext (possibly empty) as the last chunk.
            *this.encrypted_chunk = Some(EncryptedChunk {
                bytes: this.stream.encrypt_chunk(this.chunk, true)?,
                offset: 0,
            });
        }

        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_close(cx)
    }
}
408
/// Where the STREAM ciphertext begins within the inner reader.
enum StartPos {
    /// The start is not known yet. Holds the number of ciphertext bytes consumed
    /// from the inner reader so far; `StreamReader::start` derives the start as
    /// the inner reader's current position minus this count.
    Implicit(u64),
    /// The absolute start position within the inner reader, once computed.
    Explicit(u64),
}
423
/// Decrypts and authenticates a STREAM ciphertext read from `inner`.
///
/// Created by `Stream::decrypt` or `Stream::decrypt_async`.
#[pin_project]
pub struct StreamReader<R> {
    // STREAM cipher and nonce state.
    stream: Stream,
    // The reader supplying ciphertext (pinned for the async implementation).
    #[pin]
    inner: R,
    // Buffer of ENCRYPTED_CHUNK_SIZE bytes for the chunk being read.
    encrypted_chunk: Vec<u8>,
    // Number of valid ciphertext bytes currently in `encrypted_chunk`.
    encrypted_pos: usize,
    // Where the ciphertext begins in `inner` (needed for seeking).
    start: StartPos,
    // Cached, authenticated plaintext length (computed lazily when seeking from the end).
    plaintext_len: Option<u64>,
    // Current position within the plaintext.
    cur_plaintext_pos: u64,
    // The decrypted chunk containing `cur_plaintext_pos`, if one is loaded.
    chunk: Option<SecretSlice<u8>>,
}
437
438impl<R> StreamReader<R> {
439 fn count_bytes(&mut self, read: usize) {
440 if let StartPos::Implicit(offset) = &mut self.start {
442 *offset += read as u64;
443 }
444 }
445
446 fn decrypt_chunk(&mut self) -> io::Result<()> {
447 self.count_bytes(self.encrypted_pos);
448 let chunk = &self.encrypted_chunk[..self.encrypted_pos];
449
450 if chunk.is_empty() {
451 if !self.stream.is_complete() {
452 return Err(io::Error::new(
454 io::ErrorKind::UnexpectedEof,
455 "age file is truncated",
456 ));
457 }
458 } else {
459 let last = chunk.len() < ENCRYPTED_CHUNK_SIZE;
463
464 self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) {
465 (Ok(chunk), _)
466 if chunk.expose_secret().is_empty() && self.cur_plaintext_pos > 0 =>
467 {
468 assert!(last);
469 return Err(io::Error::new(
470 io::ErrorKind::InvalidData,
471 "last chunk is empty",
472 ));
473 }
474 (Ok(chunk), _) => Some(chunk),
475 (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?),
476 (Err(e), true) => return Err(e),
477 };
478 }
479
480 self.encrypted_pos = 0;
482
483 Ok(())
484 }
485
486 fn read_from_chunk(&mut self, buf: &mut [u8]) -> usize {
487 if self.chunk.is_none() {
488 return 0;
489 }
490
491 let chunk = self.chunk.as_ref().unwrap();
492 let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE;
493
494 let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len());
495
496 buf[..to_read]
497 .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]);
498 self.cur_plaintext_pos += to_read as u64;
499 if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 {
500 self.chunk = None;
502 }
503
504 to_read
505 }
506}
507
508impl<R: Read> Read for StreamReader<R> {
509 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
510 if self.chunk.is_none() {
511 while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
512 match self
513 .inner
514 .read(&mut self.encrypted_chunk[self.encrypted_pos..])
515 {
516 Ok(0) => break,
517 Ok(n) => self.encrypted_pos += n,
518 Err(e) => match e.kind() {
519 io::ErrorKind::Interrupted => (),
520 _ => return Err(e),
521 },
522 }
523 }
524 self.decrypt_chunk()?;
525 }
526
527 Ok(self.read_from_chunk(buf))
528 }
529}
530
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncRead + Unpin> AsyncRead for StreamReader<R> {
    /// Async counterpart of `Read::read`.
    ///
    /// If no decrypted chunk is loaded, this polls `inner` until a full encrypted
    /// chunk is buffered or EOF is reached (retrying on `Interrupted`), then
    /// decrypts it. Returns `Poll::Pending` whenever `inner` does. Bytes buffered
    /// so far are kept in `encrypted_pos` for the next poll.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        if self.chunk.is_some() {
            return Poll::Ready(Ok(self.read_from_chunk(buf)));
        }

        while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
            let this = self.as_mut().project();
            let pos = *this.encrypted_pos;
            let filled = match ready!(this.inner.poll_read(cx, &mut this.encrypted_chunk[pos..]))
            {
                Ok(0) => break,
                Ok(n) => n,
                Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                Err(e) => return Poll::Ready(Err(e)),
            };
            *this.encrypted_pos += filled;
        }
        self.decrypt_chunk()?;

        Poll::Ready(Ok(self.read_from_chunk(buf)))
    }
}
560
561impl<R: Read + Seek> StreamReader<R> {
562 fn start(&mut self) -> io::Result<u64> {
563 match self.start {
564 StartPos::Implicit(offset) => {
565 let current = self.inner.stream_position()?;
566 let start = current - offset;
567
568 self.start = StartPos::Explicit(start);
570
571 Ok(start)
572 }
573 StartPos::Explicit(start) => Ok(start),
574 }
575 }
576
577 fn len(&mut self) -> io::Result<u64> {
579 match self.plaintext_len {
580 None => {
581 let cur_pos = self.inner.stream_position()?;
584 let cur_nonce = self.stream.nonce.0;
585 let ct_start = self.start()?;
586 let ct_end = self.inner.seek(SeekFrom::End(0))?;
587 let ct_len = ct_end - ct_start;
588
589 let num_chunks =
591 (ct_len + (ENCRYPTED_CHUNK_SIZE as u64 - 1)) / ENCRYPTED_CHUNK_SIZE as u64;
592
593 let last_chunk_start = ct_start + ((num_chunks - 1) * ENCRYPTED_CHUNK_SIZE as u64);
596 let mut last_chunk = Vec::with_capacity((ct_end - last_chunk_start) as usize);
597 self.inner.seek(SeekFrom::Start(last_chunk_start))?;
598 self.inner.read_to_end(&mut last_chunk)?;
599 self.stream.nonce.set_counter(num_chunks - 1);
600 self.stream.decrypt_chunk(&last_chunk, true).map_err(|_| {
601 io::Error::new(
602 io::ErrorKind::InvalidData,
603 "Last chunk is invalid, stream might be truncated",
604 )
605 })?;
606
607 let total_tag_size = num_chunks * TAG_SIZE as u64;
610 let pt_len = ct_len - total_tag_size;
611
612 self.inner.seek(SeekFrom::Start(cur_pos))?;
614 self.stream.nonce = Nonce(cur_nonce);
615
616 self.plaintext_len = Some(pt_len);
618
619 Ok(pt_len)
620 }
621 Some(pt_len) => Ok(pt_len),
622 }
623 }
624}
625
626impl<R: Read + Seek> Seek for StreamReader<R> {
627 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
628 let start = self.start()?;
630 let target_pos = match pos {
631 SeekFrom::Start(offset) => offset,
632 SeekFrom::Current(offset) => {
633 let res = (self.cur_plaintext_pos as i64) + offset;
634 if res >= 0 {
635 res as u64
636 } else {
637 return Err(io::Error::new(
638 io::ErrorKind::InvalidData,
639 "cannot seek before the start",
640 ));
641 }
642 }
643 SeekFrom::End(offset) => {
644 let res = (self.len()? as i64) + offset;
645 if res >= 0 {
646 res as u64
647 } else {
648 return Err(io::Error::new(
649 io::ErrorKind::InvalidData,
650 "cannot seek before the start",
651 ));
652 }
653 }
654 };
655
656 let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64;
657
658 let target_chunk_index = target_pos / CHUNK_SIZE as u64;
659 let target_chunk_offset = target_pos % CHUNK_SIZE as u64;
660
661 if target_chunk_index == cur_chunk_index {
662 self.cur_plaintext_pos = target_pos;
664 } else {
665 self.chunk = None;
667
668 self.inner.seek(SeekFrom::Start(
670 start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64),
671 ))?;
672 self.stream.nonce.set_counter(target_chunk_index);
673 self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64;
674
675 if target_chunk_offset > 0 {
677 let mut to_drop = vec![0; target_chunk_offset as usize];
678 self.read_exact(&mut to_drop)?;
679 }
680 else if target_pos == self.len()? {
690 self.stream
691 .nonce
692 .set_last(true)
693 .expect("We unset the last chunk flag earlier");
694 }
695 }
696
697 Ok(target_pos)
699 }
700}
701
// Round-trip, truncation and seeking tests for the STREAM reader and writer.
// All tests use the same fixed payload key ([7; 32]) for encryption and decryption.
#[cfg(test)]
mod tests {
    use age_core::secrecy::ExposeSecret;
    use std::io::{self, Cursor, Read, Seek, SeekFrom, Write};

    use super::{CHUNK_SIZE, PayloadKey, Stream};

    #[cfg(feature = "async")]
    use futures::{
        io::{AsyncRead, AsyncWrite},
        pin_mut,
        task::Poll,
    };
    #[cfg(feature = "async")]
    use futures_test::task::noop_context;

    // A single non-last chunk decrypts back to the original plaintext.
    #[test]
    fn chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let encrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            s.encrypt_chunk(&data, false).unwrap()
        };

        let decrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            s.decrypt_chunk(&encrypted, false).unwrap()
        };

        assert_eq!(decrypted.expose_secret(), &data);
    }

    // A last chunk round-trips, and nothing further can be processed after it.
    #[test]
    fn last_chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let encrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            let res = s.encrypt_chunk(&data, true).unwrap();

            // Further encryption calls return an error.
            assert_eq!(
                s.encrypt_chunk(&data, false).unwrap_err().kind(),
                io::ErrorKind::WriteZero
            );
            assert_eq!(
                s.encrypt_chunk(&data, true).unwrap_err().kind(),
                io::ErrorKind::WriteZero
            );

            res
        };

        let decrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            let res = s.decrypt_chunk(&encrypted, true).unwrap();

            // Further decryption calls return an error.
            match s.decrypt_chunk(&encrypted, false) {
                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
                _ => panic!("Expected error"),
            }
            match s.decrypt_chunk(&encrypted, true) {
                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
                _ => panic!("Expected error"),
            }

            res
        };

        assert_eq!(decrypted.expose_secret(), &data);
    }

    // Encrypts `data` with the sync writer and checks the sync reader returns it.
    fn stream_round_trip(data: &[u8]) {
        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(data).unwrap();
            w.finish().unwrap();
        };

        let decrypted = {
            let mut buf = vec![];
            let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
            r.read_to_end(&mut buf).unwrap();
            buf
        };

        assert_eq!(decrypted, data);
    }

    #[test]
    fn stream_round_trip_empty() {
        stream_round_trip(&[]);
    }

    #[test]
    fn stream_round_trip_short() {
        stream_round_trip(&[42; 1024]);
    }

    #[test]
    fn stream_round_trip_chunk() {
        stream_round_trip(&[42; CHUNK_SIZE]);
    }

    #[test]
    fn stream_round_trip_long() {
        stream_round_trip(&[42; 100 * 1024]);
    }

    // Same as `stream_round_trip`, but drives the async writer and reader by hand
    // with a no-op context. The in-memory inner I/O never returns Pending.
    #[cfg(feature = "async")]
    fn stream_async_round_trip(data: &[u8]) {
        let mut encrypted = vec![];
        {
            let w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
            pin_mut!(w);

            let mut cx = noop_context();

            let mut tmp = data;
            loop {
                match w.as_mut().poll_write(&mut cx, tmp) {
                    // Ok(0) is only returned once `tmp` is empty.
                    Poll::Ready(Ok(0)) => break,
                    Poll::Ready(Ok(written)) => tmp = &tmp[written..],
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
            loop {
                match w.as_mut().poll_close(&mut cx) {
                    Poll::Ready(Ok(())) => break,
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        let decrypted = {
            let mut buf = vec![];
            let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
            pin_mut!(r);

            let mut cx = noop_context();

            let mut tmp = [0; 4096];
            loop {
                match r.as_mut().poll_read(&mut cx, &mut tmp) {
                    Poll::Ready(Ok(0)) => break buf,
                    Poll::Ready(Ok(read)) => buf.extend_from_slice(&tmp[..read]),
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_short() {
        stream_async_round_trip(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_chunk() {
        stream_async_round_trip(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_long() {
        stream_async_round_trip(&[42; 100 * 1024]);
    }

    // Round-trips through `futures::io::copy` on a tokio current-thread runtime.
    // `copy` treats a write returning Ok(0) as an error, so this exercises the
    // writer's "never return Ok(0) for a non-empty buffer" behaviour.
    #[cfg(feature = "async")]
    fn stream_async_io_copy(data: &[u8]) {
        use futures::AsyncWriteExt;

        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let mut encrypted = vec![];
        let result = runtime.block_on(async {
            let mut w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
            match futures::io::copy(data, &mut w).await {
                Ok(written) => {
                    w.close().await.unwrap();
                    Ok(written)
                }
                Err(e) => Err(e),
            }
        });

        match result {
            Ok(written) => assert_eq!(written, data.len() as u64),
            Err(e) => panic!("Unexpected error: {}", e),
        }

        let decrypted = {
            let mut buf = vec![];
            let result = runtime.block_on(async {
                let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
                futures::io::copy(r, &mut buf).await
            });

            match result {
                Ok(written) => assert_eq!(written, data.len() as u64),
                Err(e) => panic!("Unexpected error: {}", e),
            }

            buf
        };

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_short() {
        stream_async_io_copy(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_chunk() {
        stream_async_io_copy(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_long() {
        stream_async_io_copy(&[42; 100 * 1024]);
    }

    // Without `finish`, the last chunk is never written, so decryption must
    // report truncation rather than returning the partial plaintext as complete.
    #[test]
    fn stream_fails_to_decrypt_truncated_file() {
        let data = vec![42; 2 * CHUNK_SIZE];

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&data).unwrap();
        };

        let mut buf = vec![];
        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
        assert_eq!(
            r.read_to_end(&mut buf).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    // Sequential reads, then seeks from the start, current position and end.
    #[test]
    fn stream_seeking() {
        let mut data = vec![0; 100 * 1024];
        for (i, b) in data.iter_mut().enumerate() {
            *b = i as u8;
        }

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&data).unwrap();
            w.finish().unwrap();
        };

        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(encrypted));

        let mut buf = vec![0; 100];
        // Read 70_000 bytes, which crosses into the second chunk.
        for i in 0..700 {
            r.read_exact(&mut buf).unwrap();
            assert_eq!(&buf[..], &data[100 * i..100 * (i + 1)]);
        }

        // Seek back into the first chunk.
        r.seek(SeekFrom::Start(250)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[250..350]);

        // Seek forward within the same chunk (350 + 510 = 860).
        r.seek(SeekFrom::Current(510)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[860..960]);

        // Seek relative to the authenticated end.
        r.seek(SeekFrom::End(-1337)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[data.len() - 1337..data.len() - 1237]);
    }

    // Seeking from the end must authenticate the ciphertext length. Dropping a
    // byte must be detected rather than silently shifting plaintext offsets.
    #[test]
    fn seek_from_end_fails_on_truncation() {
        // The plaintext is the string "hello" followed by 65536 zeros, just enough
        // to give us two chunks.
        let mut plaintext: Vec<u8> = b"hello".to_vec();
        plaintext.extend_from_slice(&[0; 65536]);

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&plaintext).unwrap();
            w.finish().unwrap();
        };

        // Seek to the byte just after "h" via an end-relative offset.
        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
        let eof_relative_offset = 1_i64 - plaintext.len() as i64;
        reader.seek(SeekFrom::End(eof_relative_offset)).unwrap();
        let mut buf = [0; 4];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ello", "This is correct.");

        // Drop the final ciphertext byte; the same seek must now fail.
        let truncated_ciphertext = &encrypted[..encrypted.len() - 1];
        let mut truncated_reader = Stream::decrypt(
            PayloadKey([7; 32].into()),
            Cursor::new(truncated_ciphertext),
        );
        match truncated_reader.seek(SeekFrom::End(eof_relative_offset)) {
            Err(e) => {
                assert_eq!(e.kind(), io::ErrorKind::InvalidData);
                assert_eq!(
                    &e.to_string(),
                    "Last chunk is invalid, stream might be truncated",
                );
            }
            Ok(_) => panic!("This is a security issue."),
        }
    }

    // A payload of exactly one full chunk: seeking to the end lands on a chunk
    // boundary, and the following read must report EOF rather than truncation.
    #[test]
    fn seek_from_end_with_exact_chunk() {
        let plaintext: Vec<u8> = vec![42; 65536];

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&plaintext).unwrap();
            w.finish().unwrap();
        };

        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
        reader.seek(SeekFrom::End(0)).unwrap();

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();
        assert_eq!(buf.len(), 0);
    }
}