async_encrypted_stream/
write_half.rs

1use bytes::{Buf, BufMut, BytesMut};
2use chacha20poly1305::{
3    aead::{
4        generic_array::ArrayLength,
5        stream::{Encryptor, NonceSize, StreamPrimitive},
6    },
7    AeadInPlace,
8};
9
10use std::{
11    ops::Sub,
12    pin::Pin,
13    task::{ready, Poll},
14};
15
16use tokio::io::AsyncWrite;
17
18use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_CHUNK_SIZE};
19
20pin_project_lite::pin_project! {
21    /// Async Encryption Write Half.
22    ///
23    /// This struct has an internal buffer to hold encrypted bytes that were not written to the
24    /// inner writter. Under "normal" circunstances, the internal buffer will be seldom used.
25    pub struct WriteHalf<T, U> {
26        #[pin]
27        inner: T,
28        encryptor: U,
29        buffer: bytes::BytesMut,
30        chunk_size: usize
31    }
32}
33
34impl<T, A, S> WriteHalf<T, Encryptor<A, S>>
35where
36    T: AsyncWrite,
37    S: StreamPrimitive<A>,
38    A: AeadInPlace,
39    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
40    NonceSize<A, S>: ArrayLength<u8>,
41{
42    pub fn new(inner: T, encryptor: Encryptor<A, S>) -> Self {
43        Self::with_capacity(inner, encryptor, DEFAULT_BUFFER_SIZE, DEFAULT_CHUNK_SIZE)
44    }
45
46    pub fn with_capacity(
47        inner: T,
48        encryptor: Encryptor<A, S>,
49        size: usize,
50        chunk_size: usize,
51    ) -> Self {
52        Self {
53            inner,
54            encryptor,
55            buffer: BytesMut::with_capacity(size),
56            chunk_size,
57        }
58    }
59
60    /// Encrypts `buf` contents and return a [`Vec<u8>`] with 4 bytes in LE representing the encrypted content
61    /// length and the encrypted contents.
62    ///
63    /// [0, 0, 0, 0, ...]
64    ///
65    /// If the encryption fails, it returns [std::error::ErrorKind::InvalidInput]
66    fn get_encrypted(&mut self, buf: &[u8]) -> std::io::Result<Vec<u8>> {
67        let mut encrypted = self
68            .encryptor
69            .encrypt_next(buf)
70            .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
71
72        let len = (encrypted.len() as u32).to_le_bytes();
73        let mut buf = Vec::with_capacity(encrypted.len() + std::mem::size_of::<u32>());
74        buf.extend_from_slice(&len);
75        buf.append(&mut encrypted);
76
77        Ok(buf)
78    }
79
80    /// Flush the internal buffer into the inner writer. This functions does nothing if the
81    /// internal buffer is empty.   
82    ///
83    /// If the inner writter writes 0 bytes, this function will return an
84    /// [std::io::ErrorKind::WriteZero] error.
85    fn flush_buf(
86        self: Pin<&mut Self>,
87        cx: &mut std::task::Context<'_>,
88    ) -> Poll<std::io::Result<()>> {
89        let mut me = self.project();
90        while me.buffer.has_remaining() {
91            match ready!(me.inner.as_mut().poll_write(cx, &me.buffer[..])) {
92                Ok(0) => {
93                    return Poll::Ready(Err(std::io::Error::new(
94                        std::io::ErrorKind::WriteZero,
95                        "failed to write the buffered data",
96                    )));
97                }
98                Ok(n) => me.buffer.advance(n),
99                Err(e) => return Poll::Ready(Err(e)),
100            }
101        }
102
103        Poll::Ready(Ok(()))
104    }
105}
106
107impl<T, A, S> AsyncWrite for WriteHalf<T, Encryptor<A, S>>
108where
109    T: AsyncWrite + Unpin,
110    S: StreamPrimitive<A>,
111    A: AeadInPlace,
112    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
113    NonceSize<A, S>: ArrayLength<u8>,
114{
115    /// Encrypt `buf` content, write into `self.inner` and returns the number of bytes
116    /// encrypted.
117    ///
118    /// Since tokio runtime will call this function repeatedly with the same contents when
119    /// [Poll::Pending] is returned, this function may return [Poll::Pending] only when
120    /// trying to flush the internal buffer, otherwise it will always return `Poll::Ready(Ok(n))`,
121    /// even if the inner writer fails.
122    ///
123    /// This behavior was adopted to guarantee parity with the reading counterpart,
124    /// the contents of `buf` must be encrypted only once, if the internal writing operation fails,
125    /// the already encrypted contents will be written into the internal buffer instead.
126    ///
127    /// It is guaranteed that `0 <= n <= buf.len()`
128    ///
129    /// Internally, the contents of `buf` will be splitted into chunks of `self.chunk_size` size,
130    /// default to 1024 bytes, to avoid allocating a huge `Vec<u8>` when encrypting larger messages.
131    fn poll_write(
132        mut self: Pin<&mut Self>,
133        cx: &mut std::task::Context<'_>,
134        buf: &[u8],
135    ) -> std::task::Poll<Result<usize, std::io::Error>> {
136        if !self.buffer.is_empty() {
137            ready!(self.as_mut().flush_buf(cx))?
138        }
139
140        let mut total_written = 0;
141        for chunk in buf.chunks(self.chunk_size) {
142            let encrypted = self.get_encrypted(chunk)?;
143            total_written += chunk.len();
144
145            let me = self.as_mut().project();
146            match me.inner.poll_write(cx, &encrypted[..]) {
147                Poll::Ready(Ok(written)) => {
148                    if written < encrypted.len() {
149                        self.buffer.put(&encrypted[written..]);
150                        return Poll::Ready(Ok(total_written));
151                    }
152                }
153                Poll::Pending | Poll::Ready(Err(..)) => {
154                    self.buffer.put(&encrypted[..]);
155                    return Poll::Ready(Ok(total_written));
156                }
157            }
158        }
159        Poll::Ready(Ok(buf.len()))
160    }
161
162    fn poll_flush(
163        mut self: Pin<&mut Self>,
164        cx: &mut std::task::Context<'_>,
165    ) -> std::task::Poll<Result<(), std::io::Error>> {
166        ready!(self.as_mut().flush_buf(cx))?;
167        self.project().inner.poll_flush(cx)
168    }
169
170    fn poll_shutdown(
171        self: Pin<&mut Self>,
172        cx: &mut std::task::Context<'_>,
173    ) -> std::task::Poll<Result<(), std::io::Error>> {
174        self.project().inner.poll_shutdown(cx)
175    }
176}
177
#[cfg(test)]
mod tests {
    use chacha20poly1305::{aead::stream::EncryptorLE31, KeyInit, XChaCha20Poly1305};
    use tokio::io::AsyncWriteExt;

    use crate::get_key;

    use super::*;

    /// Builds an `EncryptorLE31` from `key` and `nonce`. Calling it twice with the same inputs
    /// gives the writer under test and the reference computation identical starting state.
    fn make_encryptor(key: [u8; 32], nonce: [u8; 20]) -> EncryptorLE31<XChaCha20Poly1305> {
        EncryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            nonce.as_ref().into(),
        )
    }

    #[tokio::test]
    pub async fn test_crypto_stream_write_half() {
        const MESSAGE: &[u8] = b"some content";

        let key: [u8; 32] = get_key("key", "group");
        let nonce = [0u8; 20];

        // Reference frame: 4-byte little-endian ciphertext length followed by the ciphertext.
        let ciphertext = make_encryptor(key, nonce).encrypt_next(MESSAGE).unwrap();
        let mut expected = (ciphertext.len() as u32).to_le_bytes().to_vec();
        expected.extend_from_slice(&ciphertext);

        let mut writer = WriteHalf::new(
            tokio::io::BufWriter::new(Vec::new()),
            make_encryptor(key, nonce),
        );

        let accepted = writer.write(MESSAGE).await.unwrap();
        assert_eq!(accepted, MESSAGE.len());

        assert_eq!(expected, writer.inner.buffer())
    }
}