1use chacha20poly1305::{
2 aead::{
3 generic_array::ArrayLength,
4 stream::{Decryptor, NonceSize, StreamPrimitive},
5 },
6 AeadInPlace,
7};
8use pin_project_lite::pin_project;
9use std::{ops::Sub, pin::Pin, task::ready};
10
11use tokio::io::{AsyncBufRead, AsyncRead};
12
13use crate::DEFAULT_BUFFER_SIZE;
14
15pin_project! {
16 pub struct ReadHalf<T, U> {
18
19 #[pin]
20 inner: T,
21 decryptor: U,
22 buffer: Vec<u8>,
23 pos: usize,
24 cap: usize
25 }
26}
27
28impl<T, A, S> ReadHalf<T, Decryptor<A, S>>
29where
30 S: StreamPrimitive<A>,
31 A: AeadInPlace,
32 A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
33 NonceSize<A, S>: ArrayLength<u8>,
34{
35 pub fn new(inner: T, decryptor: Decryptor<A, S>) -> Self {
36 Self::with_capacity(inner, decryptor, DEFAULT_BUFFER_SIZE)
37 }
38 pub fn with_capacity(inner: T, decryptor: Decryptor<A, S>, size: usize) -> Self {
39 Self {
40 inner,
41 decryptor,
42 buffer: vec![0u8; size],
43 pos: 0,
44 cap: 0,
45 }
46 }
47
48 fn produce(mut self: Pin<&mut Self>) -> std::io::Result<Option<Vec<u8>>> {
52 if self.cap <= self.pos {
53 return Ok(None);
54 }
55
56 let mut length_bytes = [0u8; 4];
65 length_bytes.copy_from_slice(&self.buffer[self.pos..self.pos + 4]);
66 let length = u32::from_le_bytes(length_bytes) as usize;
67
68 let me = self.as_mut().project();
69 if *me.cap >= *me.pos + length + 4 {
70 let decrypted = me
71 .decryptor
72 .decrypt_next(&me.buffer[*me.pos + 4..*me.pos + 4 + length])
73 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
74
75 *me.pos += 4 + length;
76 if *me.pos == *me.cap {
77 *me.pos = 0;
78 *me.cap = 0;
79 }
80
81 Ok(Some(decrypted))
82 } else {
83 self.adjust_buffer(length + 4);
84 Ok(None)
85 }
86 }
87
88 fn adjust_buffer(self: Pin<&mut Self>, desired_additional: usize) {
98 let me = self.project();
99 if *me.cap + desired_additional >= me.buffer.len() && *me.pos > 0 {
100 me.buffer.copy_within(*me.pos..*me.cap, 0);
101 *me.cap -= *me.pos;
102 *me.pos = 0;
103 }
104
105 if *me.pos + desired_additional > me.buffer.len() {
106 me.buffer.resize(me.buffer.len() * 2, 0);
107 }
108 }
109
110 pub fn buffer(&self) -> &[u8] {
116 &self.buffer[self.pos..]
117 }
118}
119
120impl<T, A, S> AsyncRead for ReadHalf<T, Decryptor<A, S>>
121where
122 T: AsyncRead,
123 S: StreamPrimitive<A>,
124 A: AeadInPlace,
125 A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
126 NonceSize<A, S>: ArrayLength<u8>,
127{
128 fn poll_read(
134 mut self: Pin<&mut Self>,
135 cx: &mut std::task::Context<'_>,
136 buf: &mut tokio::io::ReadBuf<'_>,
137 ) -> std::task::Poll<std::io::Result<()>> {
138 loop {
139 if let Some(decrypted) = self.as_mut().produce()? {
140 if decrypted.len() > buf.remaining() {
141 Err(std::io::Error::new(
142 std::io::ErrorKind::OutOfMemory,
143 "Decrypted value exceeds buffer capacity",
144 ))?;
145 }
146
147 buf.put_slice(&decrypted);
148 return std::task::Poll::Ready(Ok(()));
149 }
150
151 if ready!(self.as_mut().poll_fill_buf(cx))?.is_empty() {
152 return std::task::Poll::Ready(Ok(()));
153 }
154 }
155 }
156}
157
158impl<R: AsyncRead, A, S> tokio::io::AsyncBufRead for ReadHalf<R, Decryptor<A, S>>
159where
160 S: StreamPrimitive<A>,
161 A: AeadInPlace,
162 A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
163 NonceSize<A, S>: ArrayLength<u8>,
164{
165 fn poll_fill_buf(
166 self: Pin<&mut Self>,
167 cx: &mut std::task::Context<'_>,
168 ) -> std::task::Poll<std::io::Result<&[u8]>> {
169 let me = self.project();
170
171 let mut buf = tokio::io::ReadBuf::new(&mut me.buffer[*me.cap..]);
172 ready!(me.inner.poll_read(cx, &mut buf))?;
173 if !buf.filled().is_empty() {
174 *me.cap += buf.filled().len();
175 }
176
177 std::task::Poll::Ready(Ok(&me.buffer[*me.pos..*me.cap]))
178 }
179
180 fn consume(self: Pin<&mut Self>, amt: usize) {
181 let me = self.project();
182 *me.pos += amt;
183 if *me.pos >= *me.cap {
184 *me.pos = 0;
185 *me.cap = 0;
186 }
187 }
188}
189
#[cfg(test)]
mod tests {
    use std::{assert_eq, time::Duration};

    use chacha20poly1305::{aead::stream::EncryptorLE31, KeyInit, XChaCha20Poly1305};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::get_key;

    use super::*;

    /// Encrypts three messages as length-prefixed frames, trickles the ciphertext through a
    /// duplex pipe in 10-byte chunks, and checks that `ReadHalf` yields the concatenated
    /// plaintext.
    #[tokio::test]
    pub async fn test_crypto_stream_read_half() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];

        let (rx, mut tx) = tokio::io::duplex(100);

        let writer = tokio::spawn(async move {
            let encrypted_content = {
                let mut encryptor: EncryptorLE31<XChaCha20Poly1305> =
                    chacha20poly1305::aead::stream::EncryptorLE31::from_aead(
                        XChaCha20Poly1305::new(key.as_ref().into()),
                        start_nonce.as_ref().into(),
                    );

                let mut expected = Vec::new();

                // Frame layout: 4-byte little-endian ciphertext length, then the ciphertext.
                for data in ["some content", "some other content", "even more content"] {
                    let mut encrypted = encryptor.encrypt_next(data.as_bytes()).unwrap();
                    expected.extend((encrypted.len() as u32).to_le_bytes());
                    expected.append(&mut encrypted);
                }

                expected
            };

            for chunk in encrypted_content.chunks(10) {
                // `write` may accept only part of `chunk`; `write_all` guarantees every
                // ciphertext byte reaches the pipe.
                tx.write_all(chunk)
                    .await
                    .expect("writing ciphertext to the pipe failed");
                tokio::time::sleep(Duration::from_millis(20)).await;
            }
            // `tx` is dropped here, which signals EOF to the reader.
        });

        tokio::time::sleep(Duration::from_millis(20)).await;

        let decryptor = chacha20poly1305::aead::stream::DecryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            start_nonce.as_ref().into(),
        );
        let mut reader = ReadHalf::new(rx, decryptor);

        let mut plain_content = String::new();
        reader
            .read_to_string(&mut plain_content)
            .await
            .expect("reading the decrypted stream failed");
        // Report a panic in the writer task instead of silently ignoring it.
        writer.await.expect("writer task panicked");

        assert_eq!(
            plain_content,
            "some contentsome other contenteven more content"
        );
    }

    /// A frame whose ciphertext is garbage must be reported as `InvalidData`. The failing
    /// frame is not skipped, so a second read fails as well.
    #[tokio::test]
    pub async fn test_read_invalid_data() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];

        let (rx, _tx) = tokio::io::duplex(100);

        let decryptor = chacha20poly1305::aead::stream::DecryptorLE31::from_aead(
            XChaCha20Poly1305::new(key.as_ref().into()),
            start_nonce.as_ref().into(),
        );
        let mut reader = ReadHalf::new(rx, decryptor);

        // Inject a frame straight into the raw buffer: length prefix 10, then zero bytes.
        let mut reader_data = Vec::from_iter(10u32.to_le_bytes());
        reader_data.extend_from_slice(&[0u8; 20]);

        reader.cap = reader_data.len();
        reader.buffer = reader_data;

        let mut buf = [0u8; 1024];

        for _ in 0..2 {
            let err = reader
                .read(&mut buf)
                .await
                .expect_err("a corrupt frame must not decrypt");
            assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        }
    }
}