1use aes_gcm_siv::{
4 aead::{generic_array::GenericArray, Aead, KeyInit, KeySizeUser},
5 Aes256GcmSiv,
6};
7use anubis_core::secrecy::{ExposeSecret, SecretSlice};
8use pin_project::pin_project;
9use std::cmp;
10use std::io::{self, Read, Seek, SeekFrom, Write};
11use zeroize::Zeroize;
12
13#[cfg(feature = "async")]
14use futures::{
15 io::{AsyncRead, AsyncWrite, Error},
16 ready,
17 task::{Context, Poll},
18};
19#[cfg(feature = "async")]
20use std::pin::Pin;
21
/// Number of plaintext bytes in every chunk except possibly the last (64 KiB).
const CHUNK_SIZE: usize = 1 << 16;
/// Number of authentication-tag bytes appended to each encrypted chunk.
const TAG_SIZE: usize = 16;
/// Size of a full encrypted chunk: one plaintext chunk plus its tag.
const ENCRYPTED_CHUNK_SIZE: usize = TAG_SIZE + CHUNK_SIZE;
25
26pub(crate) struct PayloadKey(pub(crate) GenericArray<u8, <Aes256GcmSiv as KeySizeUser>::KeySize>);
27
28impl Drop for PayloadKey {
29 fn drop(&mut self) {
30 self.0.as_mut_slice().zeroize();
31 }
32}
33
/// Per-chunk nonce for the stream cipher.
///
/// Only the low 96 bits of the `u128` are serialized (see `to_bytes`):
/// - bits 8..96 hold the chunk counter (88 bits);
/// - bit 0 is the "last chunk" flag.
#[derive(Clone, Copy, Default)]
struct Nonce(u128);

impl Nonce {
    /// Sets the chunk counter to `val` and clears the last-chunk flag.
    fn set_counter(&mut self, val: u64) {
        self.0 = u128::from(val) << 8;
    }

    /// Advances the chunk counter by one.
    ///
    /// Returns `Err(())` if the counter would no longer fit in its 88 bits. In that case the
    /// nonce is left unchanged. Otherwise the overflowed value would serialize as an
    /// all-zero nonce, which would reuse chunk 0's nonce if the stream kept being used.
    fn increment_counter(&mut self) -> Result<(), ()> {
        match self.0.checked_add(1 << 8) {
            Some(next) if next >> (8 * 12) == 0 => {
                self.0 = next;
                Ok(())
            }
            _ => Err(()),
        }
    }

    /// Returns `true` if the last-chunk flag is set.
    fn is_last(&self) -> bool {
        self.0 & 1 != 0
    }

    /// ORs `last` into the last-chunk flag.
    ///
    /// Returns `Err(())` if the flag is already set, whatever `last` is.
    fn set_last(&mut self, last: bool) -> Result<(), ()> {
        if !self.is_last() {
            self.0 |= u128::from(last);
            Ok(())
        } else {
            Err(())
        }
    }

    /// Serializes the low 96 bits as 12 big-endian bytes.
    fn to_bytes(self) -> [u8; 12] {
        let mut bytes = [0u8; 12];
        // `to_be_bytes` yields 16 bytes, and the nonce is the trailing 12. Both lengths are
        // fixed, so this copy cannot fail. Unlike a fallible conversion with a fallback, it
        // can never silently produce an all-zero nonce.
        bytes.copy_from_slice(&self.0.to_be_bytes()[4..]);
        bytes
    }
}
78
/// An encrypted chunk waiting to be written (possibly partially) to an async writer.
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
struct EncryptedChunk {
    /// Index of the first byte of `bytes` that has not been written yet.
    offset: usize,
    /// The encrypted chunk as returned by `Stream::encrypt_chunk`.
    bytes: Vec<u8>,
}
85
86pub(crate) struct Stream {
94 aead: Aes256GcmSiv,
95 nonce: Nonce,
96}
97
98impl Stream {
99 fn new(key: PayloadKey) -> Self {
100 Stream {
101 aead: Aes256GcmSiv::new(&key.0),
102 nonce: Nonce::default(),
103 }
104 }
105
106 pub(crate) fn encrypt<W: Write>(key: PayloadKey, inner: W) -> StreamWriter<W> {
114 StreamWriter {
115 stream: Self::new(key),
116 inner,
117 chunk: Vec::with_capacity(CHUNK_SIZE),
118 #[cfg(feature = "async")]
119 encrypted_chunk: None,
120 }
121 }
122
123 #[cfg(feature = "async")]
131 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
132 pub(crate) fn encrypt_async<W: AsyncWrite>(key: PayloadKey, inner: W) -> StreamWriter<W> {
133 StreamWriter {
134 stream: Self::new(key),
135 inner,
136 chunk: Vec::with_capacity(CHUNK_SIZE),
137 encrypted_chunk: None,
138 }
139 }
140
141 pub(crate) fn decrypt<R: Read>(key: PayloadKey, inner: R) -> StreamReader<R> {
149 StreamReader {
150 stream: Self::new(key),
151 inner,
152 encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
153 encrypted_pos: 0,
154 start: StartPos::Implicit(0),
155 plaintext_len: None,
156 cur_plaintext_pos: 0,
157 chunk: None,
158 }
159 }
160
161 #[cfg(feature = "async")]
169 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
170 pub(crate) fn decrypt_async<R: AsyncRead>(key: PayloadKey, inner: R) -> StreamReader<R> {
171 StreamReader {
172 stream: Self::new(key),
173 inner,
174 encrypted_chunk: vec![0; ENCRYPTED_CHUNK_SIZE],
175 encrypted_pos: 0,
176 start: StartPos::Implicit(0),
177 plaintext_len: None,
178 cur_plaintext_pos: 0,
179 chunk: None,
180 }
181 }
182
183 fn encrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<Vec<u8>> {
184 assert!(chunk.len() <= CHUNK_SIZE);
185
186 self.nonce.set_last(last).map_err(|_| {
187 io::Error::new(io::ErrorKind::WriteZero, "last chunk has been processed")
188 })?;
189
190 let encrypted = self
191 .aead
192 .encrypt(&self.nonce.to_bytes().into(), chunk)
193 .map_err(|_| {
194 io::Error::new(
195 io::ErrorKind::Other,
196 "AES-256-GCM-SIV encryption failed unexpectedly",
197 )
198 })?;
199
200 self.nonce.increment_counter().map_err(|_| {
201 io::Error::new(
202 io::ErrorKind::WriteZero,
203 "nonce counter overflow (encrypted more than 2^88 chunks)",
204 )
205 })?;
206
207 Ok(encrypted)
208 }
209
210 fn decrypt_chunk(&mut self, chunk: &[u8], last: bool) -> io::Result<SecretSlice<u8>> {
211 assert!(chunk.len() <= ENCRYPTED_CHUNK_SIZE);
212
213 self.nonce.set_last(last).map_err(|_| {
214 io::Error::new(io::ErrorKind::InvalidData, "last chunk has been processed")
215 })?;
216
217 let decrypted = self
218 .aead
219 .decrypt(&self.nonce.to_bytes().into(), chunk)
220 .map(SecretSlice::from)
221 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "decryption error"))?;
222
223 self.nonce.increment_counter().map_err(|_| {
224 io::Error::new(
225 io::ErrorKind::InvalidData,
226 "nonce counter overflow (decrypted more than 2^88 chunks)",
227 )
228 })?;
229
230 Ok(decrypted)
231 }
232
233 fn is_complete(&self) -> bool {
234 self.nonce.is_last()
235 }
236}
237
238#[pin_project(project = StreamWriterProj)]
240pub struct StreamWriter<W> {
241 stream: Stream,
242 #[pin]
243 inner: W,
244 chunk: Vec<u8>,
245 #[cfg(feature = "async")]
246 #[cfg_attr(docsrs, doc(cfg(feature = "async")))]
247 encrypted_chunk: Option<EncryptedChunk>,
248}
249
250impl<W: Write> StreamWriter<W> {
251 pub fn finish(mut self) -> io::Result<W> {
257 let encrypted = self.stream.encrypt_chunk(&self.chunk, true)?;
258 self.inner.write_all(&encrypted)?;
259 Ok(self.inner)
260 }
261}
262
263impl<W: Write> Write for StreamWriter<W> {
264 fn write(&mut self, mut buf: &[u8]) -> io::Result<usize> {
265 let mut bytes_written = 0;
266
267 while !buf.is_empty() {
268 let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());
269 self.chunk.extend_from_slice(&buf[..to_write]);
270 bytes_written += to_write;
271 buf = &buf[to_write..];
272
273 assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);
275
276 if !buf.is_empty() {
279 let encrypted = self.stream.encrypt_chunk(&self.chunk, false)?;
280 self.inner.write_all(&encrypted)?;
281 self.chunk.clear();
282 }
283 }
284
285 Ok(bytes_written)
286 }
287
288 fn flush(&mut self) -> io::Result<()> {
289 self.inner.flush()
290 }
291}
292
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> StreamWriter<W> {
    /// Writes any pending encrypted chunk to `inner`, resuming from its saved offset.
    ///
    /// Returns `Poll::Pending` (keeping the chunk and its progress) when `inner` is not
    /// ready. Clears the pending chunk once it has been written completely.
    ///
    /// # Errors
    ///
    /// Propagates errors from `inner`. Returns `io::ErrorKind::WriteZero` if `inner`
    /// accepts zero bytes of a non-empty buffer, which would otherwise loop forever.
    fn poll_flush_chunk(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        let StreamWriterProj {
            mut inner,
            encrypted_chunk,
            ..
        } = self.project();

        if let Some(chunk) = encrypted_chunk {
            while chunk.offset < chunk.bytes.len() {
                let written =
                    ready!(inner.as_mut().poll_write(cx, &chunk.bytes[chunk.offset..]))?;
                if written == 0 {
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to write encrypted chunk",
                    )));
                }
                chunk.offset += written;
            }
        }
        *encrypted_chunk = None;

        Poll::Ready(Ok(()))
    }
}
317
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<W: AsyncWrite> AsyncWrite for StreamWriter<W> {
    /// Buffers plaintext from `buf` into the current chunk and returns how many bytes were
    /// taken.
    ///
    /// A full chunk is only encrypted (as a non-last chunk) once more input is known to
    /// follow. The encrypted bytes are queued and written by `poll_flush_chunk`.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        mut buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        // An empty write is a no-op.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        loop {
            // Finish writing any previously queued encrypted chunk first.
            ready!(self.as_mut().poll_flush_chunk(cx))?;

            let to_write = cmp::min(CHUNK_SIZE - self.chunk.len(), buf.len());

            self.as_mut()
                .project()
                .chunk
                .extend_from_slice(&buf[..to_write]);
            buf = &buf[to_write..];

            // Either all input fit, or the buffered chunk is now full.
            assert!(buf.is_empty() || self.chunk.len() == CHUNK_SIZE);

            // More data follows a full chunk, so it is safe to seal it as non-last. The
            // final chunk is sealed in `poll_close`.
            if !buf.is_empty() {
                let this = self.as_mut().project();
                *this.encrypted_chunk = Some(EncryptedChunk {
                    bytes: this.stream.encrypt_chunk(this.chunk, false)?,
                    offset: 0,
                });
                this.chunk.clear();
            }

            if to_write > 0 {
                return Poll::Ready(Ok(to_write));
            }
            // `to_write == 0` means the chunk was already full. It was just sealed above,
            // so loop to flush it and then buffer some of `buf`.
        }
    }

    /// Writes any queued encrypted chunk, then flushes `inner`.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_flush(cx)
    }

    /// Seals the buffered plaintext as the last chunk (once), writes it, then closes
    /// `inner`.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        ready!(self.as_mut().poll_flush_chunk(cx))?;

        // `is_complete` guards against sealing the last chunk twice when `poll_close` is
        // re-polled after returning `Pending`.
        if !self.stream.is_complete() {
            let this = self.as_mut().project();
            *this.encrypted_chunk = Some(EncryptedChunk {
                bytes: this.stream.encrypt_chunk(this.chunk, true)?,
                offset: 0,
            });
        }

        ready!(self.as_mut().poll_flush_chunk(cx))?;
        self.project().inner.poll_close(cx)
    }
}
402
/// Where the ciphertext begins within the inner reader.
enum StartPos {
    /// Start not yet resolved. Holds the number of ciphertext bytes read from `inner` so
    /// far, from which the start is later derived using the current stream position.
    Implicit(u64),
    /// Absolute position in `inner` of the first ciphertext byte.
    Explicit(u64),
}

/// Reader that decrypts and authenticates a chunked AES-256-GCM-SIV stream from `inner`.
#[pin_project]
pub struct StreamReader<R> {
    /// Cipher and nonce state.
    stream: Stream,
    /// The underlying ciphertext source.
    #[pin]
    inner: R,
    /// Buffer for one encrypted chunk (the constructors allocate `ENCRYPTED_CHUNK_SIZE`).
    encrypted_chunk: Vec<u8>,
    /// Number of bytes of `encrypted_chunk` filled so far.
    encrypted_pos: usize,
    /// Where the ciphertext starts in `inner`.
    start: StartPos,
    /// Cached plaintext length, filled in by `len()`.
    plaintext_len: Option<u64>,
    /// Current position within the plaintext.
    cur_plaintext_pos: u64,
    /// The currently decrypted chunk, if one is loaded.
    chunk: Option<SecretSlice<u8>>,
}
431
432impl<R> StreamReader<R> {
433 fn count_bytes(&mut self, read: usize) {
434 if let StartPos::Implicit(offset) = &mut self.start {
436 *offset += read as u64;
437 }
438 }
439
440 fn decrypt_chunk(&mut self) -> io::Result<()> {
441 self.count_bytes(self.encrypted_pos);
442 let chunk = &self.encrypted_chunk[..self.encrypted_pos];
443
444 if chunk.is_empty() {
445 if !self.stream.is_complete() {
446 return Err(io::Error::new(
448 io::ErrorKind::UnexpectedEof,
449 "age file is truncated",
450 ));
451 }
452 } else {
453 let last = chunk.len() < ENCRYPTED_CHUNK_SIZE;
457
458 self.chunk = match (self.stream.decrypt_chunk(chunk, last), last) {
459 (Ok(chunk), _)
460 if chunk.expose_secret().is_empty() && self.cur_plaintext_pos > 0 =>
461 {
462 assert!(last);
463 return Err(io::Error::new(
464 io::ErrorKind::InvalidData,
465 crate::fl!("err-stream-last-chunk-empty"),
466 ));
467 }
468 (Ok(chunk), _) => Some(chunk),
469 (Err(_), false) => Some(self.stream.decrypt_chunk(chunk, true)?),
470 (Err(e), true) => return Err(e),
471 };
472 }
473
474 self.encrypted_pos = 0;
476
477 Ok(())
478 }
479
480 fn read_from_chunk(&mut self, buf: &mut [u8]) -> usize {
481 if self.chunk.is_none() {
482 return 0;
483 }
484
485 let chunk = self.chunk.as_ref().unwrap();
486 let cur_chunk_offset = self.cur_plaintext_pos as usize % CHUNK_SIZE;
487
488 let to_read = cmp::min(chunk.expose_secret().len() - cur_chunk_offset, buf.len());
489
490 buf[..to_read]
491 .copy_from_slice(&chunk.expose_secret()[cur_chunk_offset..cur_chunk_offset + to_read]);
492 self.cur_plaintext_pos += to_read as u64;
493 if self.cur_plaintext_pos % CHUNK_SIZE as u64 == 0 {
494 self.chunk = None;
496 }
497
498 to_read
499 }
500}
501
502impl<R: Read> Read for StreamReader<R> {
503 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
504 if self.chunk.is_none() {
505 while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
506 match self
507 .inner
508 .read(&mut self.encrypted_chunk[self.encrypted_pos..])
509 {
510 Ok(0) => break,
511 Ok(n) => self.encrypted_pos += n,
512 Err(e) => match e.kind() {
513 io::ErrorKind::Interrupted => (),
514 _ => return Err(e),
515 },
516 }
517 }
518 self.decrypt_chunk()?;
519 }
520
521 Ok(self.read_from_chunk(buf))
522 }
523}
524
#[cfg(feature = "async")]
#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
impl<R: AsyncRead + Unpin> AsyncRead for StreamReader<R> {
    /// Async counterpart of `Read::read`.
    ///
    /// Fills the encrypted buffer from `inner` when no decrypted chunk is loaded, decrypts
    /// it, then copies plaintext into `buf`. Fill progress is kept in `encrypted_pos`, so a
    /// `Pending` result mid-fill resumes where it left off.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context,
        buf: &mut [u8],
    ) -> Poll<Result<usize, Error>> {
        if self.chunk.is_none() {
            while self.encrypted_pos < ENCRYPTED_CHUNK_SIZE {
                let this = self.as_mut().project();
                match ready!(this
                    .inner
                    .poll_read(cx, &mut this.encrypted_chunk[*this.encrypted_pos..]))
                {
                    // EOF from `inner`: decrypt whatever was buffered.
                    Ok(0) => break,
                    // `R: Unpin`, so fields can be mutated through `self` directly.
                    Ok(n) => self.encrypted_pos += n,
                    Err(e) => match e.kind() {
                        io::ErrorKind::Interrupted => (),
                        _ => return Poll::Ready(Err(e)),
                    },
                }
            }
            self.decrypt_chunk()?;
        }

        Poll::Ready(Ok(self.read_from_chunk(buf)))
    }
}
554
555impl<R: Read + Seek> StreamReader<R> {
556 fn start(&mut self) -> io::Result<u64> {
557 match self.start {
558 StartPos::Implicit(offset) => {
559 let current = self.inner.stream_position()?;
560 let start = current - offset;
561
562 self.start = StartPos::Explicit(start);
564
565 Ok(start)
566 }
567 StartPos::Explicit(start) => Ok(start),
568 }
569 }
570
571 fn len(&mut self) -> io::Result<u64> {
573 match self.plaintext_len {
574 None => {
575 let cur_pos = self.inner.stream_position()?;
578 let cur_nonce = self.stream.nonce.0;
579 let ct_start = self.start()?;
580 let ct_end = self.inner.seek(SeekFrom::End(0))?;
581 let ct_len = ct_end - ct_start;
582
583 let num_chunks =
585 (ct_len + (ENCRYPTED_CHUNK_SIZE as u64 - 1)) / ENCRYPTED_CHUNK_SIZE as u64;
586
587 let last_chunk_start = ct_start + ((num_chunks - 1) * ENCRYPTED_CHUNK_SIZE as u64);
590 let mut last_chunk = Vec::with_capacity((ct_end - last_chunk_start) as usize);
591 self.inner.seek(SeekFrom::Start(last_chunk_start))?;
592 self.inner.read_to_end(&mut last_chunk)?;
593 self.stream.nonce.set_counter(num_chunks - 1);
594 self.stream.decrypt_chunk(&last_chunk, true).map_err(|_| {
595 io::Error::new(
596 io::ErrorKind::InvalidData,
597 "Last chunk is invalid, stream might be truncated",
598 )
599 })?;
600
601 let total_tag_size = num_chunks * TAG_SIZE as u64;
604 let pt_len = ct_len - total_tag_size;
605
606 self.inner.seek(SeekFrom::Start(cur_pos))?;
608 self.stream.nonce = Nonce(cur_nonce);
609
610 self.plaintext_len = Some(pt_len);
612
613 Ok(pt_len)
614 }
615 Some(pt_len) => Ok(pt_len),
616 }
617 }
618}
619
620impl<R: Read + Seek> Seek for StreamReader<R> {
621 fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
622 let start = self.start()?;
624 let target_pos = match pos {
625 SeekFrom::Start(offset) => offset,
626 SeekFrom::Current(offset) => {
627 let res = (self.cur_plaintext_pos as i64) + offset;
628 if res >= 0 {
629 res as u64
630 } else {
631 return Err(io::Error::new(
632 io::ErrorKind::InvalidData,
633 "cannot seek before the start",
634 ));
635 }
636 }
637 SeekFrom::End(offset) => {
638 let res = (self.len()? as i64) + offset;
639 if res >= 0 {
640 res as u64
641 } else {
642 return Err(io::Error::new(
643 io::ErrorKind::InvalidData,
644 "cannot seek before the start",
645 ));
646 }
647 }
648 };
649
650 let cur_chunk_index = self.cur_plaintext_pos / CHUNK_SIZE as u64;
651
652 let target_chunk_index = target_pos / CHUNK_SIZE as u64;
653 let target_chunk_offset = target_pos % CHUNK_SIZE as u64;
654
655 if target_chunk_index == cur_chunk_index {
656 self.cur_plaintext_pos = target_pos;
658 } else {
659 self.chunk = None;
661
662 self.inner.seek(SeekFrom::Start(
664 start + (target_chunk_index * ENCRYPTED_CHUNK_SIZE as u64),
665 ))?;
666 self.stream.nonce.set_counter(target_chunk_index);
667 self.cur_plaintext_pos = target_chunk_index * CHUNK_SIZE as u64;
668
669 if target_chunk_offset > 0 {
671 let mut to_drop = vec![0; target_chunk_offset as usize];
672 self.read_exact(&mut to_drop)?;
673 }
674 else if target_pos == self.len()? {
684 self.stream.nonce.set_last(true).map_err(|_| {
685 io::Error::new(
686 io::ErrorKind::InvalidData,
687 "nonce is already set as last chunk",
688 )
689 })?;
690 }
691 }
692
693 Ok(target_pos)
695 }
696}
697
#[cfg(test)]
mod tests {
    use anubis_core::secrecy::ExposeSecret;
    use std::io::{self, Cursor, Read, Seek, SeekFrom, Write};

    use super::{PayloadKey, Stream, CHUNK_SIZE};

    #[cfg(feature = "async")]
    use futures::{
        io::{AsyncRead, AsyncWrite},
        pin_mut,
        task::Poll,
    };
    #[cfg(feature = "async")]
    use futures_test::task::noop_context;

    #[test]
    fn chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let encrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            s.encrypt_chunk(&data, false).unwrap()
        };

        let decrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            s.decrypt_chunk(&encrypted, false).unwrap()
        };

        assert_eq!(decrypted.expose_secret(), &data);
    }

    #[test]
    fn last_chunk_round_trip() {
        let data = vec![42; CHUNK_SIZE];

        let encrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            let res = s.encrypt_chunk(&data, true).unwrap();

            // No chunk of either kind may follow the last chunk.
            assert_eq!(
                s.encrypt_chunk(&data, false).unwrap_err().kind(),
                io::ErrorKind::WriteZero
            );
            assert_eq!(
                s.encrypt_chunk(&data, true).unwrap_err().kind(),
                io::ErrorKind::WriteZero
            );

            res
        };

        let decrypted = {
            let mut s = Stream::new(PayloadKey([7; 32].into()));
            let res = s.decrypt_chunk(&encrypted, true).unwrap();

            match s.decrypt_chunk(&encrypted, false) {
                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
                _ => panic!("Expected error"),
            }
            match s.decrypt_chunk(&encrypted, true) {
                Err(e) => assert_eq!(e.kind(), io::ErrorKind::InvalidData),
                _ => panic!("Expected error"),
            }

            res
        };

        assert_eq!(decrypted.expose_secret(), &data);
    }

    /// Encrypts `data` with the sync writer and checks the sync reader recovers it.
    fn stream_round_trip(data: &[u8]) {
        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(data).unwrap();
            w.finish().unwrap();
        };

        let decrypted = {
            let mut buf = vec![];
            let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
            r.read_to_end(&mut buf).unwrap();
            buf
        };

        assert_eq!(decrypted, data);
    }

    #[test]
    fn stream_round_trip_empty() {
        stream_round_trip(&[]);
    }

    #[test]
    fn stream_round_trip_short() {
        stream_round_trip(&[42; 1024]);
    }

    #[test]
    fn stream_round_trip_chunk() {
        stream_round_trip(&[42; CHUNK_SIZE]);
    }

    #[test]
    fn stream_round_trip_long() {
        stream_round_trip(&[42; 100 * 1024]);
    }

    /// Encrypts `data` with the async writer and checks the async reader recovers it,
    /// polling manually with a no-op context (the in-memory buffers never return Pending).
    #[cfg(feature = "async")]
    fn stream_async_round_trip(data: &[u8]) {
        let mut encrypted = vec![];
        {
            let w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
            pin_mut!(w);

            let mut cx = noop_context();

            let mut tmp = data;
            loop {
                match w.as_mut().poll_write(&mut cx, tmp) {
                    Poll::Ready(Ok(0)) => break,
                    Poll::Ready(Ok(written)) => tmp = &tmp[written..],
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
            loop {
                match w.as_mut().poll_close(&mut cx) {
                    Poll::Ready(Ok(())) => break,
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        let decrypted = {
            let mut buf = vec![];
            let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
            pin_mut!(r);

            let mut cx = noop_context();

            let mut tmp = [0; 4096];
            loop {
                match r.as_mut().poll_read(&mut cx, &mut tmp) {
                    Poll::Ready(Ok(0)) => break buf,
                    Poll::Ready(Ok(read)) => buf.extend_from_slice(&tmp[..read]),
                    Poll::Ready(Err(e)) => panic!("Unexpected error: {}", e),
                    Poll::Pending => panic!("Unexpected Pending"),
                }
            }
        };

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_empty() {
        stream_async_round_trip(&[]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_short() {
        stream_async_round_trip(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_chunk() {
        stream_async_round_trip(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_round_trip_long() {
        stream_async_round_trip(&[42; 100 * 1024]);
    }

    /// Round-trips `data` through the async writer/reader using `futures::io::copy` on a
    /// single-threaded tokio runtime.
    #[cfg(feature = "async")]
    fn stream_async_io_copy(data: &[u8]) {
        use futures::AsyncWriteExt;

        let runtime = tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap();
        let mut encrypted = vec![];
        let result = runtime.block_on(async {
            let mut w = Stream::encrypt_async(PayloadKey([7; 32].into()), &mut encrypted);
            match futures::io::copy(data, &mut w).await {
                Ok(written) => {
                    w.close().await.unwrap();
                    Ok(written)
                }
                Err(e) => Err(e),
            }
        });

        match result {
            Ok(written) => assert_eq!(written, data.len() as u64),
            Err(e) => panic!("Unexpected error: {}", e),
        }

        let decrypted = {
            let mut buf = vec![];
            let result = runtime.block_on(async {
                let r = Stream::decrypt_async(PayloadKey([7; 32].into()), &encrypted[..]);
                futures::io::copy(r, &mut buf).await
            });

            match result {
                Ok(written) => assert_eq!(written, data.len() as u64),
                Err(e) => panic!("Unexpected error: {}", e),
            }

            buf
        };

        assert_eq!(decrypted, data);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_empty() {
        stream_async_io_copy(&[]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_short() {
        stream_async_io_copy(&[42; 1024]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_chunk() {
        stream_async_io_copy(&[42; CHUNK_SIZE]);
    }

    #[cfg(feature = "async")]
    #[test]
    fn stream_async_io_copy_long() {
        stream_async_io_copy(&[42; 100 * 1024]);
    }

    #[test]
    fn stream_fails_to_decrypt_truncated_file() {
        let data = vec![42; 2 * CHUNK_SIZE];

        // The writer is dropped without `finish`, so the last chunk is never written.
        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&data).unwrap();
        };

        let mut buf = vec![];
        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), &encrypted[..]);
        assert_eq!(
            r.read_to_end(&mut buf).unwrap_err().kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn stream_seeking() {
        let mut data = vec![0; 100 * 1024];
        for (i, b) in data.iter_mut().enumerate() {
            *b = i as u8;
        }

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&data).unwrap();
            w.finish().unwrap();
        };

        let mut r = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(encrypted));

        // Read sequentially across the first chunk boundary.
        let mut buf = vec![0; 100];
        for i in 0..700 {
            r.read_exact(&mut buf).unwrap();
            assert_eq!(&buf[..], &data[100 * i..100 * (i + 1)]);
        }

        r.seek(SeekFrom::Start(250)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[250..350]);

        r.seek(SeekFrom::Current(510)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[860..960]);

        r.seek(SeekFrom::End(-1337)).unwrap();
        r.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], &data[data.len() - 1337..data.len() - 1237]);
    }

    #[test]
    fn seek_from_end_fails_on_truncation() {
        // 5 + 65536 bytes: one full chunk followed by a short last chunk.
        let mut plaintext: Vec<u8> = b"hello".to_vec();
        plaintext.extend_from_slice(&[0; 65536]);

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&plaintext).unwrap();
            w.finish().unwrap();
        };

        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
        let eof_relative_offset = 1_i64 - plaintext.len() as i64;
        reader.seek(SeekFrom::End(eof_relative_offset)).unwrap();
        let mut buf = [0; 4];
        reader.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"ello", "This is correct.");

        // Dropping one ciphertext byte must make the length unauthenticatable.
        let truncated_ciphertext = &encrypted[..encrypted.len() - 1];
        let mut truncated_reader = Stream::decrypt(
            PayloadKey([7; 32].into()),
            Cursor::new(truncated_ciphertext),
        );
        match truncated_reader.seek(SeekFrom::End(eof_relative_offset)) {
            Err(e) => {
                assert_eq!(e.kind(), io::ErrorKind::InvalidData);
                assert_eq!(
                    &e.to_string(),
                    "Last chunk is invalid, stream might be truncated",
                );
            }
            Ok(_) => panic!("This is a security issue."),
        }
    }

    #[test]
    fn seek_from_end_with_exact_chunk() {
        let plaintext: Vec<u8> = vec![42; 65536];

        let mut encrypted = vec![];
        {
            let mut w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.write_all(&plaintext).unwrap();
            w.finish().unwrap();
        };

        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
        reader.seek(SeekFrom::End(0)).unwrap();

        // Reading at the end of a stream whose last chunk is full must yield EOF.
        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();
        assert_eq!(buf.len(), 0);
    }

    #[test]
    fn seek_from_end_on_empty_stream() {
        // An empty payload is a single tag-only last chunk.
        let mut encrypted = vec![];
        {
            let w = Stream::encrypt(PayloadKey([7; 32].into()), &mut encrypted);
            w.finish().unwrap();
        };

        let mut reader = Stream::decrypt(PayloadKey([7; 32].into()), Cursor::new(&encrypted));
        assert_eq!(reader.seek(SeekFrom::End(0)).unwrap(), 0);

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();
        assert!(buf.is_empty());
    }
}