navajo 0.0.4

cryptographic APIs
Documentation
use core::task::Poll::{Pending, Ready};

use alloc::{collections::VecDeque, vec, vec::Vec};
use futures::{stream::FusedStream, Stream, TryStream};
use pin_project::pin_project;

use crate::{
    error::{DecryptStreamError, EncryptTryStreamError},
    rand::Rng,
    Aad, Aead, SystemRng,
};

use super::{Decryptor, Encryptor, Segment};

pub trait AeadTryStream: TryStream {
    fn encrypt<C, A>(self, segment: Segment, aad: Aad<A>, cipher: C) -> EncryptTryStream<Self, A>
    where
        Self: Sized,
        Self::Ok: AsRef<[u8]>,
        Self::Error: Send + Sync,
        C: AsRef<Aead> + Send + Sync,
        A: AsRef<[u8]> + Send + Sync,
    {
        EncryptTryStream::new(SystemRng, self, segment, aad, cipher)
    }
    fn decrypt<C, A>(self, aad: Aad<A>, cipher: C) -> DecryptTryStream<Self, C, A>
    where
        Self: Sized,
        Self::Ok: AsRef<[u8]>,
        Self::Error: Send + Sync,
        C: AsRef<Aead> + Send + Sync,
        A: AsRef<[u8]> + Send + Sync,
    {
        DecryptTryStream::new(self, cipher, aad)
    }
}
impl<T> AeadTryStream for T
where
    T: TryStream,
    T::Ok: AsRef<[u8]>,
    T::Error: Send + Sync,
{
}

#[pin_project]
pub struct EncryptTryStream<S, A, N = SystemRng>
where
    N: Rng,
    S: TryStream + Sized,
    S::Ok: AsRef<[u8]>,
    S::Error: Send + Sync,
    A: AsRef<[u8]> + Send + Sync,
{
    #[pin]
    stream: S,
    encryptor: Option<Encryptor<Vec<u8>, N>>,
    queue: VecDeque<Vec<u8>>,
    aad: Aad<A>,
    done: bool,
}

impl<S, A, N> EncryptTryStream<S, A, N>
where
    N: Rng,
    S: TryStream + Sized,
    S::Ok: AsRef<[u8]>,
    S::Error: Send + Sync,
    A: AsRef<[u8]> + Send + Sync,
{
    pub fn new<C>(rng: N, stream: S, segment: Segment, aad: Aad<A>, cipher: C) -> Self
    where
        C: AsRef<Aead>,
    {
        let encryptor = Encryptor::new(rng, cipher.as_ref(), Some(segment), vec![]);
        Self {
            stream,
            encryptor: Some(encryptor),
            queue: VecDeque::new(),
            aad,
            done: false,
        }
    }
}
impl<S, D, N> FusedStream for EncryptTryStream<S, D, N>
where
    N: Rng,
    S: TryStream + Sized,
    S::Ok: AsRef<[u8]>,
    S::Error: Send + Sync,
    D: AsRef<[u8]> + Send + Sync,
{
    fn is_terminated(&self) -> bool {
        self.done
    }
}

impl<S, D, N> Stream for EncryptTryStream<S, D, N>
where
    N: Rng,
    S: TryStream,
    S::Ok: AsRef<[u8]>,
    S::Error: Send + Sync,
    D: AsRef<[u8]> + Send + Sync,
{
    type Item = Result<Vec<u8>, EncryptTryStreamError<S::Error>>;
    fn poll_next(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            if *this.done {
                return Ready(None);
            }
            if let Some(next) = this.queue.pop_front() {
                return Ready(Some(Ok(next)));
            }

            if this.encryptor.is_none() {
                *this.done = true;
                return Ready(None);
            }
            let mut encryptor = this.encryptor.take().unwrap();
            match this.stream.as_mut().try_poll_next(cx) {
                Ready(data) => match data {
                    Some(resp) => match resp {
                        Ok(data) => {
                            let result = encryptor.update(Aad(this.aad.as_ref()), data.as_ref());
                            match result {
                                Err(e) => {
                                    *this.done = true;
                                    return Ready(Some(Err(e.into())));
                                }
                                Ok(_) => {
                                    if let Some(ciphertext) = encryptor.next() {
                                        this.queue.push_back(ciphertext);
                                    }
                                    this.encryptor.replace(encryptor);
                                }
                            }
                        }
                        Err(err) => {
                            *this.done = true;
                            return Ready(Some(Err(EncryptTryStreamError::Upstream(err))));
                        }
                    },
                    None => {
                        let result = encryptor.finalize(Aad(this.aad.as_ref()));
                        match result {
                            Err(e) => {
                                *this.done = true;
                                return Ready(Some(Err(e.into())));
                            }
                            Ok(iter) => {
                                this.queue.extend(iter);
                            }
                        }
                    }
                },
                Pending => return Pending,
            }
        }
    }
}

#[pin_project]
pub struct DecryptTryStream<S, C, A>
where
    S: TryStream + Sized,
    C: AsRef<Aead> + Send + Sync,
    A: AsRef<[u8]> + Send + Sync,
{
    #[pin]
    stream: S,
    decryptor: Option<Decryptor<C, Vec<u8>>>,
    queue: VecDeque<Vec<u8>>,
    aad: Aad<A>,
    done: bool,
}

impl<S, C, A> DecryptTryStream<S, C, A>
where
    S: TryStream + Sized,
    C: AsRef<Aead> + Send + Sync,
    A: AsRef<[u8]> + Send + Sync,
{
    pub fn new(stream: S, cipher: C, aad: Aad<A>) -> Self {
        let decryptor = Decryptor::new(cipher, vec![]);
        Self {
            aad,
            stream,
            decryptor: Some(decryptor),
            queue: VecDeque::new(),
            done: false,
        }
    }
}

impl<S, C, D> Stream for DecryptTryStream<S, C, D>
where
    S: TryStream,
    S::Ok: AsRef<[u8]>,
    S::Error: Send + Sync,
    C: AsRef<Aead> + Send + Sync,
    D: AsRef<[u8]> + Send + Sync,
{
    type Item = Result<Vec<u8>, DecryptStreamError<S::Error>>;
    fn poll_next(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Option<Self::Item>> {
        let mut this = self.project();
        loop {
            if *this.done {
                return Ready(None);
            }
            if let Some(next) = this.queue.pop_front() {
                return Ready(Some(Ok(next)));
            }
            if this.decryptor.is_none() {
                return Ready(None);
            }
            let mut decryptor = this.decryptor.take().unwrap();

            match this.stream.as_mut().try_poll_next(cx) {
                Ready(resp) => match resp {
                    Some(result) => match result {
                        Ok(data) => {
                            decryptor.update(Aad(this.aad.as_ref()), data.as_ref())?;
                            match decryptor.next(Aad(this.aad.as_ref())) {
                                Err(e) => {
                                    *this.done = true;
                                    return Ready(Some(Err(e.into())));
                                }
                                Ok(Some(plaintext)) => {
                                    this.queue.push_back(plaintext);
                                    this.decryptor.replace(decryptor);
                                }
                                Ok(None) => {
                                    this.decryptor.replace(decryptor);
                                }
                            }
                        }
                        Err(err) => {
                            *this.done = true;
                            return Ready(Some(Err(DecryptStreamError::Upstream(err))));
                        }
                    },
                    None => {
                        let result = decryptor.finalize(Aad(this.aad.as_ref()));
                        match result {
                            Err(e) => {
                                *this.done = true;
                                return Ready(Some(Err(e.into())));
                            }
                            Ok(iter) => {
                                this.queue.extend(iter);
                            }
                        }
                    }
                },
                Pending => return Pending,
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::aead::Algorithm;
    use alloc::string::String;
    use futures::{stream, StreamExt, TryStreamExt};

    fn to_try_stream(d: Vec<u8>) -> Result<Vec<u8>, String> {
        Ok(d)
    }
    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_stream_roundtrip() {
        let data_stream = stream::iter(vec![
            Vec::from(&b"hello"[..]),
            Vec::from(&b" "[..]),
            Vec::from(&b"world"[..]),
        ])
        .map(to_try_stream);

        let aead = Aead::new(Algorithm::Aes256Gcm, None);
        let encrypt_stream =
            data_stream.encrypt(Segment::FourKilobytes, Aad::empty(), aead.clone());
        let ciphertext: Vec<u8> = encrypt_stream.try_concat().await.unwrap();

        let ciphertext_stream =
            stream::iter(ciphertext.chunks(40).map(Vec::from)).map(to_try_stream);
        let decrypt_stream = ciphertext_stream.decrypt(Aad::empty(), aead);
        let result = decrypt_stream.try_concat().await.unwrap();
        assert_eq!(result, b"hello world");
    }
    #[cfg(feature = "std")]
    #[tokio::test]
    async fn test_stream_with_aad_roundtrip() {
        let mut data = vec![0u8; 5556];
        let rng = SystemRng::new();
        rng.fill(&mut data).unwrap();
        let data_stream = stream::iter(data.chunks(122).map(Vec::from)).map(to_try_stream);
        let algorithm = Algorithm::Aes256Gcm;
        let aead = Aead::new(algorithm, None);
        let encrypt_stream = data_stream.encrypt(
            Segment::FourKilobytes,
            Aad(Vec::from(&b"additional data"[..])),
            aead.clone(),
        );
        let result: Vec<Vec<u8>> = encrypt_stream.try_collect().await.unwrap();

        let ciphertext = result.concat();

        println!("ciphertext: {}", ciphertext.len());
        let ciphertext_stream =
            stream::iter(ciphertext.chunks(40).map(Vec::from)).map(to_try_stream);
        let plaintext: Vec<u8> = ciphertext_stream
            .decrypt(Vec::from(&b"additional data"[..]).into(), aead)
            .try_concat()
            .await
            .unwrap();
        assert_eq!(plaintext, data);
    }
}