1use anubis_core::secrecy::{ExposeSecret, SecretSlice};
4use chacha20poly1305::{
5 aead::{generic_array::GenericArray, Aead, KeyInit, KeySizeUser},
6 ChaCha20Poly1305,
7};
8use pin_project::pin_project;
9use std::cmp;
10use std::io::{self, Read, Seek, SeekFrom, Write};
11use zeroize::Zeroize;
12
13#[cfg(feature = "async")]
14use futures::{
15 io::{AsyncRead, AsyncWrite, Error},
16 ready,
17 task::{Context, Poll},
18};
19#[cfg(feature = "async")]
20use std::pin::Pin;
21
/// Number of plaintext bytes in every chunk except (possibly) the last (64 KiB).
const CHUNK_SIZE: usize = 1 << 16;
/// Length in bytes of the Poly1305 authentication tag appended to each chunk.
const TAG_SIZE: usize = 16;
/// Length in bytes of a full encrypted chunk: plaintext followed by its tag.
const ENCRYPTED_CHUNK_SIZE: usize = TAG_SIZE + CHUNK_SIZE;
25
26pub(crate) struct PayloadKey(
27 pub(crate) GenericArray<u8, <ChaCha20Poly1305 as KeySizeUser>::KeySize>,
28);
29
30impl Drop for PayloadKey {
31 fn drop(&mut self) {
32 self.0.as_mut_slice().zeroize();
33 }
34}
35
/// The 96-bit per-chunk nonce.
///
/// It lives in the low 96 bits of a `u128`. Bits 8..96 hold an 88-bit chunk
/// counter; bit 0 is the "last chunk" flag. The other bits of the low byte are
/// never set.
#[derive(Clone, Copy, Default)]
struct Nonce(u128);

impl Nonce {
    /// Points the nonce at chunk number `val`, clearing the last-chunk flag.
    fn set_counter(&mut self, val: u64) {
        let counter = u128::from(val);
        self.0 = counter << 8;
    }

    /// Advances the nonce to the next chunk.
    ///
    /// # Panics
    /// Panics if the counter grows past the 12 bytes of the nonce.
    fn increment_counter(&mut self) {
        const ONE_CHUNK: u128 = 1 << 8;
        const NONCE_BITS: u32 = 96;

        self.0 += ONE_CHUNK;
        if (self.0 >> NONCE_BITS) > 0 {
            panic!("We overflowed the nonce!");
        }
    }

    /// Reports whether the last-chunk flag is set.
    fn is_last(&self) -> bool {
        (self.0 & 1) == 1
    }

    /// Sets the last-chunk flag when `last` is true; `false` leaves it clear.
    ///
    /// # Errors
    /// Returns `Err(())` if the flag is already set, whatever `last` is.
    fn set_last(&mut self, last: bool) -> Result<(), ()> {
        if self.is_last() {
            return Err(());
        }
        if last {
            self.0 |= 1;
        }
        Ok(())
    }

    /// Serializes the nonce as 12 big-endian bytes.
    fn to_bytes(self) -> [u8; 12] {
        let wide = self.0.to_be_bytes();
        let mut nonce = [0u8; 12];
        nonce.copy_from_slice(&wide[4..]);
        nonce
    }
}
76
/// An encrypted chunk that the async writer is still pushing into its inner
/// writer, possibly across several polls.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
struct EncryptedChunk {
    /// Ciphertext of the chunk, as returned by `Stream::encrypt_chunk`.
    bytes: Vec<u8>,
    /// How many leading bytes of `bytes` the inner writer has already accepted.
    offset: usize,
}
83
84pub(crate) struct Stream {
92 aead: ChaCha20Poly1305,
93 nonce: Nonce,
94}
95
96impl Stream {
97 fn new(key: PayloadKey) -> Self {
98 Stream {
99 aead: ChaCha20Poly1305::new(&key.0),
100 nonce: Nonce::default(),
101 }
102 }
103
104 pub(crate) fn encrypt<W: Write>(key: PayloadKey, inner: W) -> StreamWriter<W> {
112 StreamWriter {
113 stream: Self::new(key),
114 inner,
115 chunk: Vec::with_capacity(CHUNK_SIZE),
116 #[cfg(feature = "async")]
117 encrypted_chunk: None,
118 }
119 }
120
121 #[cfg(feature = "async")]
129 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
130 pub(crate) fn encrypt_async<W: AsyncWrite>(key: PayloadKey, inner: W) -> StreamWriter<W> {
131 StreamWriter {
132 stream: Self::new(key),
133 inner,
134 chunk: Vec::with_capacity(CHUNK_SIZE),
135 encrypted_chunk: None,
136 }
137 }
138
139 pub(crate) fn decrypt<R: Read>(key: PayloadKey, inner: R) -> StreamReader<R> {
147 StreamReader {
148 stream: Self::new(key),
149 inner,
150 encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
151 encrypted_pos: 0,
152 start: StartPos::Implicit(0),
153 plaintext_len: None,
154 cur_plaintext_pos: 0,
155 chunk: None,
156 }
157 }
158
159 #[cfg(feature = "async")]
167 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
168 pub(crate) fn decrypt_async<R: AsyncRead>(key: PayloadKey, inner: R) -> StreamReader<R> {
169 StreamReader {
170 stream: Self::new(key),
171 inner,
172 encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
173 encrypted_pos: 0,
174 start: StartPos::Implicit(0),
175 plaintext_len: None,
176 cur_plaintext_pos: 0,
177 chunk: None,
178 }
179 }
180
181 fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<Vec<u8>> {
182 assert!(chunk.len() <= CHUNK_SIZE);
183
184 self.nonce.set_last(last).map_err(|_| {
185 io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
186 })?;
187
188 let encrypted = self
189 .aead
190 .encrypt(&self.nonce.to_bytes().into(), chunk)
191 .expect("we will never hit chacha20::MAX_BLOCKS because of the chunk size");
192 self.nonce.increment_counter();
193
194 Ok(encrypted)
195 }
196
197 fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<SecretSlice<u8>> {
198 assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE);
199
200 self.nonce.set_last(last).map_err(|_| {
201 io::Error::new(io::ErrorKind::InvalidData, "last chunk has been processed")
202 })?;
203
204 let decrypted = self
205 .aead
206 .decrypt(&self.nonce.to_bytes().into(), chunk)
207 .map(SecretSlice::from)
208 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?;
209 self.nonce.increment_counter();
210
211 Ok(decrypted)
212 }
213
214 fn is_complete(&self) -> bool {
215 self.nonce.is_last()
216 }
217}
218
219#[pin_project(project = StreamWriterProj)]
221pub struct StreamWriter<W> {
222 stream: Stream,
223 #[pin]
224 inner: W,
225 chunk: Vec<u8>,
226 #[cfg(feature = "async")]
227 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
228 encrypted_chunk: Option<EncryptedChunk>,
229}
230
231impl<W: Write> StreamWriter<W> {
232 pub fn finish(mut self) -> io::Result<W> {
238 let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?;
239 self.inner.write_all(&encrypted)?;
240 Ok(self.inner)
241 }
242}
243
244impl<W: Write> Write for StreamWriter<W> {
245 fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
246 let mut bytes_written = 0;
247
248 while !buf.is_empty() {
249 let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());
250 self.chunk.extend_from_slice(&buf[..to_write]);
251 bytes_written += to_write;
252 buf = &buf[to_write..];
253
254 assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);
256
257 if !buf.is_empty() {
260 let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?;
261 self.inner.write_all(&encrypted)?;
262 self.chunk.clear();
263 }
264 }
265
266 Ok(bytes_written)
267 }
268
269 fn flush(&mut self) -> io::Result<()> {
270 self.inner.flush()
271 }
272}
273
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> StreamWriter<W> {
    /// Pushes any pending encrypted chunk into the inner writer.
    ///
    /// Progress is recorded in `EncryptedChunk::offset`. A `Poll::Pending`
    /// from the inner writer therefore resumes where it stopped on the next
    /// call. Once the chunk is fully written it is cleared.
    ///
    /// # Errors
    /// Propagates errors from the inner writer. Returns `ErrorKind::WriteZero`
    /// if the inner writer accepts zero bytes of a non-empty buffer, which
    /// would otherwise make this loop spin forever.
    fn poll_flush_chunk(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let StreamWriterProj {
            mut inner,
            encrypted_chunk,
            ..
        } = self.project();

        if let Some(chunk) = encrypted_chunk {
            while chunk.offset < chunk.bytes.len() {
                let written =
                    ready!(inner.as_mut().poll_write(cx, &chunk.bytes[chunk.offset..]))?;
                if written == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write encrypted chunk",
                    )));
                }
                chunk.offset += written;
            }
        }
        *encrypted_chunk = None;

        Poll::Ready(Ok(()))
    }
}
298
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
    /// Accepts a prefix of `buf` into the current chunk, returning how many
    /// bytes were taken. A non-empty `buf` never yields `Ok(0)`: when the
    /// current chunk is already full, it is encrypted and queued, and the loop
    /// runs again so some bytes are accepted. This keeps `io::copy`-style
    /// helpers from failing with `WriteZero`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        mut buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        loop {
            // A previously queued encrypted chunk must reach `inner` first.
            ready!(self.as_mut().poll_flush_chunk(cx))?;

            let this = self.as_mut().project();
            let room = CHUNK_SIZE - this.chunk.len();
            let accepted = cmp::min(room, buf.len());
            let (head, rest) = buf.split_at(accepted);
            this.chunk.extend_from_slice(head);
            buf = rest;

            // Either all input was taken, or the chunk is full.
            assert!(buf.is_empty() || this.chunk.len() == CHUNK_SIZE);

            // A full chunk is only queued once more data is known to follow;
            // the final chunk is produced by `poll_close`.
            if !buf.is_empty() {
                *this.encrypted_chunk = Some(EncryptedChunk {
                    bytes: this.stream.encrypt_chunk(this.chunk, false)?,
                    offset: 0,
                });
                this.chunk.clear();
            }

            if accepted > 0 {
                return Poll::Ready(Ok(accepted));
            }
        }
    }

    /// Drains the queued encrypted chunk, then flushes the inner writer.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_flush(cx)
    }

    /// Writes the final chunk (exactly once, even across `Pending` re-polls),
    /// then closes the inner writer.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        ready!(self.as_mut().poll_flush_chunk(cx))?;

        if !self.stream.is_complete() {
            let this = self.as_mut().project();
            let final_chunk = this.stream.encrypt_chunk(this.chunk, true)?;
            *this.encrypted_chunk = Some(EncryptedChunk {
                bytes: final_chunk,
                offset: 0,
            });
        }

        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_close(cx)
    }
}
383
/// Location of the first ciphertext byte within the reader's inner stream.
enum StartPos {
    /// Not yet resolved. Holds the number of ciphertext bytes consumed from
    /// the inner reader so far; subtracting it from the inner stream's current
    /// position yields the start.
    Implicit(u64),
    /// The resolved absolute position of the first ciphertext byte.
    Explicit(u64),
}
398
399#[pin_project]
401pub struct StreamReader<R> {
402 stream: Stream,
403 #[pin]
404 inner: R,
405 encrypted_chunk: Vec<u8>,
406 encrypted_pos: usize,
407 start: StartPos,
408 plaintext_len: Option<u64>,
409 cur_plaintext_pos: u64,
410 chunk: Option<SecretSlice<u8>>,
411}
412
413impl<R> StreamReader<R> {
414 fn count_bytes(&mut self, read: usize) {
415 if let StartPos::Implicit(offset) = &mut self.start {
417 *offset += read as u64;
418 }
419 }
420
421 fn decrypt_chunk(&mut self) -> io::Result<()> {
422 self.count_bytes(self.encrypted_pos);
423 let chunk = &self.encrypted_chunk[..self.encrypted_pos];
424
425 if chunk.is_empty() {
426 if !self.stream.is_complete() {
427 return Err(io::Error::new(
429 io::ErrorKind::UnexpectedEof,
430 "age file is truncated",
431 ));
432 }
433 } else {
434 let last = chunk.len() < ENCRYPTED_CHUNK_SIZE;
438
439 self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) {
440 (Ok(chunk), _)
441 if chunk.expose_secret().is_empty() && self.cur_plaintext_pos > 0 =>
442 {
443 assert!(last);
444 return Err(io::Error::new(
445 io::ErrorKind::InvalidData,
446 crate::fl!("err-stream-last-chunk-empty"),
447 ));
448 }
449 (Ok(chunk), _) => Some(chunk),
450 (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?),
451 (Err(e), true) => return Err(e),
452 };
453 }
454
455 self.encrypted_pos = 0;
457
458 Ok(())
459 }
460
461 fn read_from_chunk(&mut self, buf: &mut [u8]) -> usize {
462 if self.chunk.is_none() {
463 return 0;
464 }
465
466 let chunk = self.chunk.as_ref().unwrap();
467 let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE;
468
469 let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len());
470
471 buf[..to_read]
472 .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]);
473 self.cur_plaintext_pos += to_read as u64;
474 if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 {
475 self.chunk = None;
477 }
478
479 to_read
480 }
481}
482
483impl<R: Read> Read for StreamReader<R> {
484 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
485 if self.chunk.is_none() {
486 while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
487 match self
488 .inner
489 .read(&mut self.encrypted_chunk[self.encrypted_pos..])
490 {
491 Ok(0) => break,
492 Ok(n) => self.encrypted_pos += n,
493 Err(e) => match e.kind() {
494 io::ErrorKind::Interrupted => (),
495 _ => return Err(e),
496 },
497 }
498 }
499 self.decrypt_chunk()?;
500 }
501
502 Ok(self.read_from_chunk(buf))
503 }
504}
505
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncRead + Unpin> AsyncRead for StreamReader<R> {
    /// Async counterpart of `Read::read`.
    ///
    /// Ciphertext collected before a `Pending` is kept in `encrypted_chunk`,
    /// so the next poll continues filling the same chunk.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        // With `R: Unpin` the reader itself is `Unpin`, so a plain `&mut` suffices.
        let this = self.get_mut();

        if this.chunk.is_none() {
            while this.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
                let dest = &mut this.encrypted_chunk[this.encrypted_pos..];
                match ready!(Pin::new(&mut this.inner).poll_read(cx, dest)) {
                    Ok(0) => break,
                    Ok(n) => this.encrypted_pos += n,
                    Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(e) => return Poll::Ready(Err(e)),
                }
            }
            this.decrypt_chunk()?;
        }

        Poll::Ready(Ok(this.read_from_chunk(buf)))
    }
}
535
536impl<R: Read + Seek> StreamReader<R> {
537 fn start(&mut self) -> io::Result<u64> {
538 match self.start {
539 StartPos::Implicit(offset) => {
540 let current = self.inner.stream_position()?;
541 let start = current - offset;
542
543 self.start = StartPos::Explicit(start);
545
546 Ok(start)
547 }
548 StartPos::Explicit(start) => Ok(start),
549 }
550 }
551
552 fn len(&mut self) -> io::Result<u64> {
554 match self.plaintext_len {
555 None => {
556 let cur_pos = self.inner.stream_position()?;
559 let cur_nonce = self.stream.nonce.0;
560 let ct_start = self.start()?;
561 let ct_end = self.inner.seek(SeekFrom::End(0))?;
562 let ct_len = ct_end - ct_start;
563
564 let num_chunks =
566 (ct_len + (ENCRYPTED_CHUNK_SIZE as u64 - 1)) / ENCRYPTED_CHUNK_SIZE as u64;
567
568 let last_chunk_start = ct_start + ((num_chunks - 1) * ENCRYPTED_CHUNK_SIZE as u64);
571 let mut last_chunk = Vec::with_capacity((ct_end - last_chunk_start) as usize);
572 self.inner.seek(SeekFrom::Start(last_chunk_start))?;
573 self.inner.read_to_end(&mut last_chunk)?;
574 self.stream.nonce.set_counter(num_chunks - 1);
575 self.stream.decrypt_chunk(&last_chunk, true).map_err(|_| {
576 io::Error::new(
577 io::ErrorKind::InvalidData,
578 "Last chunk is invalid, stream might be truncated",
579 )
580 })?;
581
582 let total_tag_size = num_chunks * TAG_SIZE as u64;
585 let pt_len = ct_len - total_tag_size;
586
587 self.inner.seek(SeekFrom::Start(cur_pos))?;
589 self.stream.nonce = Nonce(cur_nonce);
590
591 self.plaintext_len = Some(pt_len);
593
594 Ok(pt_len)
595 }
596 Some(pt_len) => Ok(pt_len),
597 }
598 }
599}
600
601impl<R: Read + Seek> Seek for StreamReader<R> {
602 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
603 let start = self.start()?;
605 let target_pos = match pos {
606 SeekFrom::Start(offset) => offset,
607 SeekFrom::Current(offset) => {
608 let res = (self.cur_plaintext_pos as i64) + offset;
609 if res >= 0 {
610 res as u64
611 } else {
612 return Err(io::Error::new(
613 io::ErrorKind::InvalidData,
614 "cannot seek before the start",
615 ));
616 }
617 }
618 SeekFrom::End(offset) => {
619 let res = (self.len()? as i64) + offset;
620 if res >= 0 {
621 res as u64
622 } else {
623 return Err(io::Error::new(
624 io::ErrorKind::InvalidData,
625 "cannot seek before the start",
626 ));
627 }
628 }
629 };
630
631 let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64;
632
633 let target_chunk_index = target_pos / CHUNK_SIZE as u64;
634 let target_chunk_offset = target_pos % CHUNK_SIZE as u64;
635
636 if target_chunk_index == cur_chunk_index {
637 self.cur_plaintext_pos = target_pos;
639 } else {
640 self.chunk = None;
642
643 self.inner.seek(SeekFrom::Start(
645 start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64),
646 ))?;
647 self.stream.nonce.set_counter(target_chunk_index);
648 self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64;
649
650 if target_chunk_offset > 0 {
652 let mut to_drop = vec![0; target_chunk_offset as usize];
653 self.read_exact(&mut to_drop)?;
654 }
655 else if target_pos == self.len()? {
665 self.stream
666 .nonce
667 .set_last(true)
668 .expect("We unset the last chunk flag earlier");
669 }
670 }
671
672 Ok(target_pos)
674 }
675}
676
#[cfg(test)]
mod tests {
    use anubis_core::secrecy::ExposeSecret;
    use std::io::{self, Cursor, Read, Seek, SeekFrom, Write};

    use super::{PayloadKey, Stream, CHUNK_SIZE};

    #[cfg(feature = "async")]
    use futures::{
        io::{AsyncRead, AsyncWrite},
        pin_mut,
        task::Poll,
    };
    #[cfg(feature = "async")]
    use futures_test::task::noop_context;

    /// The fixed key shared by every test in this module.
    fn test_key() -> PayloadKey {
        PayloadKey([7; 32].into())
    }

    /// Encrypts `data` with the synchronous writer and finishes the stream.
    fn encrypt_all(data: &[u8]) -> Vec<u8> {
        let mut encrypted = vec![];
        let mut w = Stream::encrypt(test_key(), &mut encrypted);
        w.write_all(data).unwrap();
        w.finish().unwrap();
        encrypted
    }

    #[test]
    fn chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let encrypted = Stream::new(test_key()).encrypt_chunk(&data, false).unwrap();
        let decrypted = Stream::new(test_key())
            .decrypt_chunk(&encrypted, false)
            .unwrap();

        assert_eq!(decrypted.expose_secret(), &data);
    }

    #[test]
    fn last_chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let mut encryptor = Stream::new(test_key());
        let encrypted = encryptor.encrypt_chunk(&data, true).unwrap();
        // Once the final chunk exists, nothing further may be encrypted.
        for last in [false, true] {
            assert_eq!(
                encryptor.encrypt_chunk(&data, last).unwrap_err().kind(),
                io::ErrorKind::WriteZero
            );
        }

        let mut decryptor = Stream::new(test_key());
        let decrypted = decryptor.decrypt_chunk(&encrypted, true).unwrap();
        // ...and nothing further may be decrypted either.
        for last in [false, true] {
            match decryptor.decrypt_chunk(&encrypted, last) {
                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
                Ok(_) => panic!("Expected error"),
            }
        }

        assert_eq!(decrypted.expose_secret(), &data);
    }

    fn stream_round_trip(data: &[u8]) {
        let encrypted = encrypt_all(data);

        let mut decrypted = vec![];
        Stream::decrypt(test_key(), &encrypted[..])
            .read_to_end(&mut decrypted)
            .unwrap();

        assert_eq!(decrypted, data);
    }

    #[test]
    fn stream_round_trip_empty() {
        stream_round_trip(&[]);
    }

    #[test]
    fn stream_round_trip_short() {
        stream_round_trip(&[42; 1024]);
    }

    #[test]
    fn stream_round_trip_chunk() {
        stream_round_trip(&[42; CHUNK_SIZE]);
    }

    #[test]
    fn stream_round_trip_long() {
        stream_round_trip(&[42; 100 * 1024]);
    }

    #[cfg(feature = "async")]
    fn stream_async_round_trip(data: &[u8]) {
        let mut cx = noop_context();

        let mut encrypted = vec![];
        {
            let w = Stream::encrypt_async(test_key(), &mut encrypted);
            pin_mut!(w);

            let mut pending = data;
            loop {
                match w.as_mut().poll_write(&mut cx, pending) {
                    Poll::Ready(Ok(0)) => break,
                    Poll::Ready(Ok(n)) => pending = &pending[n..],
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
            match w.as_mut().poll_close(&mut cx) {
                Poll::Ready(Ok(())) => (),
                Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                Poll::Pending => panic!("Unexpected Pending"),
            }
        }

        let r = Stream::decrypt_async(test_key(), &encrypted[..]);
        pin_mut!(r);

        let mut decrypted = vec![];
        let mut scratch = [0; 4096];
        loop {
            match r.as_mut().poll_read(&mut cx, &mut scratch) {
                Poll::Ready(Ok(0)) => break,
                Poll::Ready(Ok(n)) => decrypted.extend_from_slice(&scratch[..n]),
                Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                Poll::Pending => panic!("Unexpected Pending"),
            }
        }

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_short() {
        stream_async_round_trip(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_chunk() {
        stream_async_round_trip(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_long() {
        stream_async_round_trip(&[42; 100 * 1024]);
    }

    #[cfg(feature = "async")]
    fn stream_async_io_copy(data: &[u8]) {
        use futures::AsyncWriteExt;

        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();

        let mut encrypted = vec![];
        let written = runtime
            .block_on(async {
                let mut w = Stream::encrypt_async(test_key(), &mut encrypted);
                let written = futures::io::copy(data, &mut w).await?;
                w.close().await.unwrap();
                Ok::<_, io::Error>(written)
            })
            .unwrap_or_else(|e| panic!("Unexpected error: {}", e));
        assert_eq!(written, data.len() as u64);

        let mut decrypted = vec![];
        let copied = runtime
            .block_on(futures::io::copy(
                Stream::decrypt_async(test_key(), &encrypted[..]),
                &mut decrypted,
            ))
            .unwrap_or_else(|e| panic!("Unexpected error: {}", e));
        assert_eq!(copied, data.len() as u64);

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_short() {
        stream_async_io_copy(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_chunk() {
        stream_async_io_copy(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_long() {
        stream_async_io_copy(&[42; 100 * 1024]);
    }

    #[test]
    fn stream_fails_to_decrypt_truncated_file() {
        let data = vec![42; 2 * CHUNK_SIZE];

        // Write everything but never call `finish`, so the final chunk is missing.
        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(test_key(), &mut encrypted);
            w.write_all(&data).unwrap();
        }

        let mut buf = vec![];
        let err = Stream::decrypt(test_key(), &encrypted[..])
            .read_to_end(&mut buf)
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }

    #[test]
    fn stream_seeking() {
        let data: Vec<u8> = (0usize..100 * 1024).map(|i| i as u8).collect();

        let mut r = Stream::decrypt(test_key(), Cursor::new(encrypt_all(&data)));

        // Sequential reads straight through, crossing a chunk boundary.
        let mut buf = vec![0; 100];
        for expected in data.chunks_exact(100).take(700) {
            r.read_exact(&mut buf).unwrap();
            assert_eq!(&buf[..], expected);
        }

        // Seek back into the first chunk.
        r.seek(SeekFrom::Start(250)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[250..350]);

        // Relative seek forward.
        r.seek(SeekFrom::Current(510)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[860..960]);

        // Seek relative to the end.
        r.seek(SeekFrom::End(-1337)).unwrap();
        r.read_exact(&mut buf).unwrap();
        let tail = data.len() - 1337;
        assert_eq!(&buf[..], &data[tail..tail + 100]);
    }

    #[test]
    fn seek_from_end_fails_on_truncation() {
        // A short first chunk's worth of text followed by padding, so the
        // payload spans two chunks.
        let mut plaintext: Vec<u8> = b"hello".to_vec();
        plaintext.extend_from_slice(&[0; 65536]);
        let encrypted = encrypt_all(&plaintext);

        // An offset from the end that lands on byte 1 of the plaintext.
        let eof_relative_offset = 1_i64 - plaintext.len() as i64;

        let mut reader = Stream::decrypt(test_key(), Cursor::new(&encrypted));
        reader.seek(SeekFrom::End(eof_relative_offset)).unwrap();
        let mut buf = [0; 4];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ello", "This is correct.");

        // Dropping a single ciphertext byte must make the end-relative seek fail.
        let truncated_ciphertext = &encrypted[..encrypted.len() - 1];
        let mut truncated_reader =
            Stream::decrypt(test_key(), Cursor::new(truncated_ciphertext));
        match truncated_reader.seek(SeekFrom::End(eof_relative_offset)) {
            Err(e) => {
                assert_eq!(e.kind(), io::ErrorKind::InvalidData);
                assert_eq!(
                    &e.to_string(),
                    "Last chunk is invalid, stream might be truncated",
                );
            }
            Ok(_) => panic!("This is a security issue."),
        }
    }

    #[test]
    fn seek_from_end_with_exact_chunk() {
        let plaintext: Vec<u8> = vec![42; 65536];
        let encrypted = encrypt_all(&plaintext);

        let mut reader = Stream::decrypt(test_key(), Cursor::new(&encrypted));
        reader.seek(SeekFrom::End(0)).unwrap();

        // At EOF, so reading to the end yields nothing (and no error).
        let mut rest = Vec::new();
        reader.read_to_end(&mut rest).unwrap();
        assert!(rest.is_empty());
    }
}