navajo 0.0.4

cryptographic APIs
Documentation
use core::task::Poll::{Pending, Ready};

use super::{Decryptor, Encryptor, Segment};
use crate::{
    error::{DecryptError, EncryptError},
    Aad, Aead, Rng, SystemRng,
};
use alloc::{collections::VecDeque, vec, vec::Vec};
use futures::{
    stream::{Fuse, FusedStream},
    Stream, StreamExt,
};
use pin_project::pin_project;

pub trait AeadStream: Stream {
    fn encrypt<C, A>(self, segment: Segment, aad: Aad<A>, aead: C) -> EncryptStream<Self, A>
    where
        Self: Sized,
        Self::Item: AsRef<[u8]>,
        C: AsRef<Aead> + Send + Sync,
        A: AsRef<[u8]> + Send + Sync,
    {
        EncryptStream::new(SystemRng, self, segment, aad, aead)
    }
    fn decrypt<C, A>(self, aad: Aad<A>, cipher: C) -> DecryptStream<Self, C, A>
    where
        Self: Sized,
        Self::Item: AsRef<[u8]>,
        C: AsRef<Aead> + Send + Sync,
        A: AsRef<[u8]> + Send + Sync,
    {
        DecryptStream::new(SystemRng, self, cipher, aad)
    }
}

impl<T> AeadStream for T
where
    T: Stream,
    T::Item: AsRef<[u8]>,
{
}

#[pin_project]
pub struct EncryptStream<S, A, N = SystemRng>
where
    S: Stream + Sized,
    A: AsRef<[u8]>,
    N: Rng,
{
    #[pin]
    stream: Fuse<S>,
    encryptor: Option<Encryptor<Vec<u8>, N>>,
    queue: VecDeque<Vec<u8>>,
    aad: Aad<A>,
    done: bool,
}

impl<S, A, N> EncryptStream<S, A, N>
where
    S: Stream + Sized,
    A: AsRef<[u8]> + Send + Sync,
    N: Rng,
{
    pub fn new<C>(rng: N, stream: S, segment: Segment, aad: Aad<A>, aead: C) -> Self
    where
        C: AsRef<Aead> + Send + Sync,
    {
        let encryptor = Encryptor::new(rng, aead.as_ref(), Some(segment), vec![]);
        Self {
            stream: stream.fuse(),
            encryptor: Some(encryptor),
            queue: VecDeque::new(),
            aad,
            done: false,
        }
    }
}
impl<S, A, N> FusedStream for EncryptStream<S, A, N>
where
    S: Stream + Sized,
    S::Item: AsRef<[u8]>,
    A: AsRef<[u8]> + Send + Sync,
    N: Rng,
{
    fn is_terminated(&self) -> bool {
        self.done
    }
}

impl<S, D, N> Stream for EncryptStream<S, D, N>
where
    S: Stream,
    S::Item: AsRef<[u8]>,
    D: AsRef<[u8]> + Send + Sync,
    N: Rng,
{
    type Item = Result<Vec<u8>, EncryptError>;
    fn poll_next(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            if *this.done {
                return Ready(None);
            }
            if let Some(next) = this.queue.pop_front() {
                return Ready(Some(Ok(next)));
            }
            if this.stream.is_terminated() {
                *this.done = true;
                return Ready(None);
            }
            if this.encryptor.is_none() {
                *this.done = true;
                return Ready(None);
            }
            let mut encryptor = this.encryptor.take().unwrap();
            match this.stream.as_mut().poll_next(cx) {
                Ready(data) => match data {
                    Some(data) => {
                        let result = encryptor.update(Aad(this.aad.as_ref()), data.as_ref());
                        match result {
                            Err(e) => {
                                *this.done = true;
                                return Ready(Some(Err(e)));
                            }
                            Ok(_) => {
                                if let Some(ciphertext) = encryptor.next() {
                                    this.queue.push_back(ciphertext);
                                }
                                this.encryptor.replace(encryptor);
                            }
                        }
                    }
                    None => {
                        let result = encryptor.finalize(Aad(this.aad.as_ref()));
                        match result {
                            Err(e) => {
                                *this.done = true;
                                return Ready(Some(Err(e)));
                            }
                            Ok(iter) => {
                                this.queue.extend(iter);
                            }
                        }
                    }
                },
                Pending => return Pending,
            }
        }
    }
}
#[pin_project]
pub struct DecryptStream<S, C, A, N = SystemRng>
where
    N: Rng,
    S: Stream + Sized,
    C: AsRef<Aead> + Send + Sync,
    A: AsRef<[u8]>,
{
    #[pin]
    stream: Fuse<S>,
    decryptor: Option<Decryptor<C, Vec<u8>>>,
    queue: VecDeque<Vec<u8>>,
    aad: Aad<A>,
    done: bool,
    rng: N,
}

impl<S, C, A, N> DecryptStream<S, C, A, N>
where
    N: Rng,
    S: Stream,
    C: AsRef<Aead> + Send + Sync,
    A: AsRef<[u8]> + Send + Sync,
{
    pub fn new(rng: N, stream: S, cipher: C, aad: Aad<A>) -> Self {
        let decryptor = Decryptor::new(cipher, vec![]);
        Self {
            rng,
            aad,
            stream: stream.fuse(),
            decryptor: Some(decryptor),
            queue: VecDeque::new(),
            done: false,
        }
    }
}

impl<S, C, D, N> Stream for DecryptStream<S, C, D, N>
where
    N: Rng,
    S: Stream,
    S::Item: AsRef<[u8]>,
    C: AsRef<Aead> + Send + Sync,
    D: AsRef<[u8]> + Send + Sync,
{
    type Item = Result<Vec<u8>, DecryptError>;
    fn poll_next(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            if *this.done {
                return Ready(None);
            }
            if let Some(next) = this.queue.pop_front() {
                return Ready(Some(Ok(next)));
            }
            if this.decryptor.is_none() {
                return Ready(None);
            }
            if this.stream.is_terminated() {
                return Ready(None);
            }
            let mut decryptor = this.decryptor.take().unwrap();
            match this.stream.as_mut().poll_next(cx) {
                Ready(data) => match data {
                    Some(data) => {
                        decryptor.update(Aad(this.aad.as_ref()), data.as_ref())?;
                        let result = decryptor.next(Aad(this.aad.as_ref()));
                        match result {
                            Err(e) => {
                                *this.done = true;
                                return Ready(Some(Err(e)));
                            }
                            Ok(Some(plaintext)) => {
                                this.queue.push_back(plaintext);
                                this.decryptor.replace(decryptor);
                            }
                            Ok(None) => {
                                this.decryptor.replace(decryptor);
                            }
                        }
                    }
                    None => {
                        let result = decryptor.finalize(Aad(this.aad.as_ref()));
                        match result {
                            Err(e) => {
                                *this.done = true;
                                return Ready(Some(Err(e)));
                            }
                            Ok(iter) => {
                                this.queue.extend(iter);
                            }
                        }
                    }
                },
                Pending => return Pending,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_stream_roundtrip() {
        use super::*;
        use crate::aead::Algorithm;
        use futures::{stream, TryStreamExt};

        let data_stream = stream::iter(vec![
            Vec::from("hello".as_bytes()),
            Vec::from(" ".as_bytes()),
            Vec::from("world".as_bytes()),
        ]);
        let aead = Aead::new(Algorithm::Aes256Gcm, None);
        let encrypt_stream =
            data_stream.encrypt(Segment::FourKilobytes, Aad::empty(), aead.clone());
        let ciphertext: Vec<u8> = encrypt_stream.try_concat().await.unwrap();

        let decrypt_stream =
            stream::iter(ciphertext.chunks(40).map(Vec::from)).decrypt(Aad::empty(), aead);
        let result = decrypt_stream.try_concat().await.unwrap();
        assert_eq!(result, b"hello world");
    }

    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_stream_with_aad_roundtrip() {
        use super::*;
        use crate::aead::Algorithm;
        use futures::{stream, TryStreamExt};

        let mut data = vec![0u8; 5556];
        let rng = crate::SystemRng::new();
        rng.fill(&mut data).unwrap();
        let data_stream = stream::iter(data.chunks(122).map(Vec::from));
        let algorithm = Algorithm::Aes256Gcm;
        let aead = Aead::new(algorithm, None);
        let encrypt_stream = data_stream.encrypt(
            Segment::FourKilobytes,
            Aad(Vec::from(&b"additional data"[..])),
            aead.clone(),
        );
        let result: Vec<Vec<u8>> = encrypt_stream.try_collect().await.unwrap();

        let ciphertext = result.concat();

        println!("ciphertext: {}", ciphertext.len());

        let plaintext: Vec<u8> = stream::iter(ciphertext.chunks(40).map(Vec::from))
            .decrypt(Aad(Vec::from(&b"additional data"[..])), aead)
            .try_concat()
            .await
            .unwrap();
        assert_eq!(plaintext, data);
    }
}