// async_encrypted_stream/read_half.rs

1use chacha20poly1305::{
2    aead::{
3        generic_array::ArrayLength,
4        stream::{Decryptor, NonceSize, StreamPrimitive},
5    },
6    AeadInPlace,
7};
8use pin_project_lite::pin_project;
9use std::{ops::Sub, pin::Pin, task::ready};
10
11use tokio::io::{AsyncBufRead, AsyncRead};
12
13use crate::DEFAULT_BUFFER_SIZE;
14
15pin_project! {
16    /// Async Encryption Read Half
17    pub struct ReadHalf<T, U> {
18
19        #[pin]
20        inner: T,
21        decryptor: U,
22        buffer: Vec<u8>,
23        pos: usize,
24        cap: usize
25    }
26}
27
28impl<T, A, S> ReadHalf<T, Decryptor<A, S>>
29where
30    S: StreamPrimitive<A>,
31    A: AeadInPlace,
32    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
33    NonceSize<A, S>: ArrayLength<u8>,
34{
35    pub fn new(inner: T, decryptor: Decryptor<A, S>) -> Self {
36        Self::with_capacity(inner, decryptor, DEFAULT_BUFFER_SIZE)
37    }
38    pub fn with_capacity(inner: T, decryptor: Decryptor<A, S>, size: usize) -> Self {
39        Self {
40            inner,
41            decryptor,
42            buffer: vec![0u8; size],
43            pos: 0,
44            cap: 0,
45        }
46    }
47
48    /// Produce a value if there is enough data in the internal buffer
49    ///
50    /// When a value is produced, it will advance the buffer to the position for the next value.
51    fn produce(mut self: Pin<&mut Self>) -> std::io::Result<Option<Vec<u8>>> {
52        if self.cap <= self.pos {
53            return Ok(None);
54        }
55
56        // Producing a value is a relatively simple operation.
57        // Read 4 bytes from the buffer and cast to a u32 as the length of the message.
58        // If there is enough bytes in the buffer, read the bytes and decrypt the message.
59        //
60        // Then advance the buffer to the next position (4 + length)
61        //
62        // If there isn't enough bytes to produce a message, just return None
63
64        let mut length_bytes = [0u8; 4];
65        length_bytes.copy_from_slice(&self.buffer[self.pos..self.pos + 4]);
66        let length = u32::from_le_bytes(length_bytes) as usize;
67
68        let me = self.as_mut().project();
69        if *me.cap >= *me.pos + length + 4 {
70            let decrypted = me
71                .decryptor
72                .decrypt_next(&me.buffer[*me.pos + 4..*me.pos + 4 + length])
73                .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
74
75            *me.pos += 4 + length;
76            if *me.pos == *me.cap {
77                *me.pos = 0;
78                *me.cap = 0;
79            }
80
81            Ok(Some(decrypted))
82        } else {
83            self.adjust_buffer(length + 4);
84            Ok(None)
85        }
86    }
87
88    /// Adjusts the buffer to fit the next full message.
89    ///
90    /// When the buffer reach a position where the length of the message is greater than the buffer
91    /// available capacity, it is necessary to reset the buffer position to 0 and move the bytes
92    /// available to the beginning of the buffer, freeing buffer capacity to be filled.
93    ///
94    /// It is also possible that the message length is bigger than the buffer full size, in this
95    /// case the buffer will be resized to double it's full capacity. This operation should not
96    /// be necessary because the writter is limited to write 1024 bytes long messages
97    fn adjust_buffer(self: Pin<&mut Self>, desired_additional: usize) {
98        let me = self.project();
99        if *me.cap + desired_additional >= me.buffer.len() && *me.pos > 0 {
100            me.buffer.copy_within(*me.pos..*me.cap, 0);
101            *me.cap -= *me.pos;
102            *me.pos = 0;
103        }
104
105        if *me.pos + desired_additional > me.buffer.len() {
106            me.buffer.resize(me.buffer.len() * 2, 0);
107        }
108    }
109
110    /// Return the contents of the internal buffer at the current position, for diagnostic
111    /// purposes.
112    ///
113    /// For each message available in the buffer, the first 4 bytes are the message length encoded
114    /// as a **little endian** u32. The end of the buffer may contain incomplete data.
115    pub fn buffer(&self) -> &[u8] {
116        &self.buffer[self.pos..]
117    }
118}
119
120impl<T, A, S> AsyncRead for ReadHalf<T, Decryptor<A, S>>
121where
122    T: AsyncRead,
123    S: StreamPrimitive<A>,
124    A: AeadInPlace,
125    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
126    NonceSize<A, S>: ArrayLength<u8>,
127{
128    /// The poll read simply tries to produce a value from the internal buffer.
129    /// If no value is produced, it then tries to poll more bytes from the inner reader
130    ///
131    /// This function may return a [std::io::ErrorKind::InvalidData] if it is not possible to decrypt
132    /// the message, in this case, further read attempts will always produce the same error.
133    fn poll_read(
134        mut self: Pin<&mut Self>,
135        cx: &mut std::task::Context<'_>,
136        buf: &mut tokio::io::ReadBuf<'_>,
137    ) -> std::task::Poll<std::io::Result<()>> {
138        loop {
139            if let Some(decrypted) = self.as_mut().produce()? {
140                if decrypted.len() > buf.remaining() {
141                    Err(std::io::Error::new(
142                        std::io::ErrorKind::OutOfMemory,
143                        "Decrypted value exceeds buffer capacity",
144                    ))?;
145                }
146
147                buf.put_slice(&decrypted);
148                return std::task::Poll::Ready(Ok(()));
149            }
150
151            if ready!(self.as_mut().poll_fill_buf(cx))?.is_empty() {
152                return std::task::Poll::Ready(Ok(()));
153            }
154        }
155    }
156}
157
158impl<R: AsyncRead, A, S> tokio::io::AsyncBufRead for ReadHalf<R, Decryptor<A, S>>
159where
160    S: StreamPrimitive<A>,
161    A: AeadInPlace,
162    A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
163    NonceSize<A, S>: ArrayLength<u8>,
164{
165    fn poll_fill_buf(
166        self: Pin<&mut Self>,
167        cx: &mut std::task::Context<'_>,
168    ) -> std::task::Poll<std::io::Result<&[u8]>> {
169        let me = self.project();
170
171        let mut buf = tokio::io::ReadBuf::new(&mut me.buffer[*me.cap..]);
172        ready!(me.inner.poll_read(cx, &mut buf))?;
173        if !buf.filled().is_empty() {
174            *me.cap += buf.filled().len();
175        }
176
177        std::task::Poll::Ready(Ok(&me.buffer[*me.pos..*me.cap]))
178    }
179
180    fn consume(self: Pin<&mut Self>, amt: usize) {
181        let me = self.project();
182        *me.pos += amt;
183        if *me.pos >= *me.cap {
184            *me.pos = 0;
185            *me.cap = 0;
186        }
187    }
188}
189
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use chacha20poly1305::{aead::stream::EncryptorLE31, KeyInit, XChaCha20Poly1305};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::get_key;

    use super::*;

    /// Frames that arrive in small, delayed chunks must be reassembled and decrypted in order.
    #[tokio::test]
    pub async fn test_crypto_stream_read_half() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];

        let (rx, mut tx) = tokio::io::duplex(100);

        tokio::spawn(async move {
            let encrypted_content = {
                let mut encryptor: EncryptorLE31<XChaCha20Poly1305> =
                    chacha20poly1305::aead::stream::EncryptorLE31::from_aead(
                        XChaCha20Poly1305::new(key.as_ref().into()),
                        start_nonce.as_ref().into(),
                    );

                let mut expected = Vec::new();

                for data in ["some content", "some other content", "even more content"] {
                    let mut encrypted = encryptor.encrypt_next(data.as_bytes()).unwrap();
                    expected.extend((encrypted.len() as u32).to_le_bytes());
                    expected.append(&mut encrypted);
                }

                expected
            };

            for chunk in encrypted_content.chunks(10) {
                // `write` may accept only part of `chunk`; `write_all` sends every byte.
                tx.write_all(chunk)
                    .await
                    .expect("failed to write encrypted chunk");
                tokio::time::sleep(Duration::from_millis(20)).await;
            }
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

        let decryptor = chacha20poly1305::aead::stream::DecryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            start_nonce.as_ref().into(),
        );
        let mut reader = ReadHalf::new(rx, decryptor);

        let mut plain_content = String::new();
        reader
            .read_to_string(&mut plain_content)
            .await
            .expect("failed to read the decrypted stream");

        assert_eq!(
            plain_content,
            "some contentsome other contenteven more content"
        );
    }

    /// A frame that fails to decrypt yields an error on every subsequent read.
    #[tokio::test]
    pub async fn test_read_invalid_data() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];

        // Keep the writer alive so the reader never observes EOF.
        let (rx, _tx) = tokio::io::duplex(100);

        let decryptor = chacha20poly1305::aead::stream::DecryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            start_nonce.as_ref().into(),
        );
        let mut reader = ReadHalf::new(rx, decryptor);

        // A 10-byte frame of zeros: too short to hold an authentication tag.
        let mut reader_data = Vec::from_iter(10u32.to_le_bytes());
        reader_data.extend_from_slice(&[0u8; 20]);

        reader.cap = reader_data.len();
        reader.buffer = reader_data;

        let mut buf = [0u8; 1024];

        assert!(reader.read(&mut buf).await.is_err());
        assert!(reader.read(&mut buf).await.is_err());
    }
}