async_encrypted_stream/
write_half.rs1use bytes::{Buf, BufMut, BytesMut};
2use chacha20poly1305::{
3 aead::{
4 generic_array::ArrayLength,
5 stream::{Encryptor, NonceSize, StreamPrimitive},
6 },
7 AeadInPlace,
8};
9
10use std::{
11 ops::Sub,
12 pin::Pin,
13 task::{ready, Poll},
14};
15
16use tokio::io::AsyncWrite;
17
18use crate::{DEFAULT_BUFFER_SIZE, DEFAULT_CHUNK_SIZE};
19
20pin_project_lite::pin_project! {
21 pub struct WriteHalf<T, U> {
26 #[pin]
27 inner: T,
28 encryptor: U,
29 buffer: bytes::BytesMut,
30 chunk_size: usize
31 }
32}
33
34impl<T, A, S> WriteHalf<T, Encryptor<A, S>>
35where
36 T: AsyncWrite,
37 S: StreamPrimitive<A>,
38 A: AeadInPlace,
39 A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
40 NonceSize<A, S>: ArrayLength<u8>,
41{
42 pub fn new(inner: T, encryptor: Encryptor<A, S>) -> Self {
43 Self::with_capacity(inner, encryptor, DEFAULT_BUFFER_SIZE, DEFAULT_CHUNK_SIZE)
44 }
45
46 pub fn with_capacity(
47 inner: T,
48 encryptor: Encryptor<A, S>,
49 size: usize,
50 chunk_size: usize,
51 ) -> Self {
52 Self {
53 inner,
54 encryptor,
55 buffer: BytesMut::with_capacity(size),
56 chunk_size,
57 }
58 }
59
60 fn get_encrypted(&mut self, buf: &[u8]) -> std::io::Result<Vec<u8>> {
67 let mut encrypted = self
68 .encryptor
69 .encrypt_next(buf)
70 .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))?;
71
72 let len = (encrypted.len() as u32).to_le_bytes();
73 let mut buf = Vec::with_capacity(encrypted.len() + std::mem::size_of::<u32>());
74 buf.extend_from_slice(&len);
75 buf.append(&mut encrypted);
76
77 Ok(buf)
78 }
79
80 fn flush_buf(
86 self: Pin<&mut Self>,
87 cx: &mut std::task::Context<'_>,
88 ) -> Poll<std::io::Result<()>> {
89 let mut me = self.project();
90 while me.buffer.has_remaining() {
91 match ready!(me.inner.as_mut().poll_write(cx, &me.buffer[..])) {
92 Ok(0) => {
93 return Poll::Ready(Err(std::io::Error::new(
94 std::io::ErrorKind::WriteZero,
95 "failed to write the buffered data",
96 )));
97 }
98 Ok(n) => me.buffer.advance(n),
99 Err(e) => return Poll::Ready(Err(e)),
100 }
101 }
102
103 Poll::Ready(Ok(()))
104 }
105}
106
107impl<T, A, S> AsyncWrite for WriteHalf<T, Encryptor<A, S>>
108where
109 T: AsyncWrite + Unpin,
110 S: StreamPrimitive<A>,
111 A: AeadInPlace,
112 A::NonceSize: Sub<<S as StreamPrimitive<A>>::NonceOverhead>,
113 NonceSize<A, S>: ArrayLength<u8>,
114{
115 fn poll_write(
132 mut self: Pin<&mut Self>,
133 cx: &mut std::task::Context<'_>,
134 buf: &[u8],
135 ) -> std::task::Poll<Result<usize, std::io::Error>> {
136 if !self.buffer.is_empty() {
137 ready!(self.as_mut().flush_buf(cx))?
138 }
139
140 let mut total_written = 0;
141 for chunk in buf.chunks(self.chunk_size) {
142 let encrypted = self.get_encrypted(chunk)?;
143 total_written += chunk.len();
144
145 let me = self.as_mut().project();
146 match me.inner.poll_write(cx, &encrypted[..]) {
147 Poll::Ready(Ok(written)) => {
148 if written < encrypted.len() {
149 self.buffer.put(&encrypted[written..]);
150 return Poll::Ready(Ok(total_written));
151 }
152 }
153 Poll::Pending | Poll::Ready(Err(..)) => {
154 self.buffer.put(&encrypted[..]);
155 return Poll::Ready(Ok(total_written));
156 }
157 }
158 }
159 Poll::Ready(Ok(buf.len()))
160 }
161
162 fn poll_flush(
163 mut self: Pin<&mut Self>,
164 cx: &mut std::task::Context<'_>,
165 ) -> std::task::Poll<Result<(), std::io::Error>> {
166 ready!(self.as_mut().flush_buf(cx))?;
167 self.project().inner.poll_flush(cx)
168 }
169
170 fn poll_shutdown(
171 self: Pin<&mut Self>,
172 cx: &mut std::task::Context<'_>,
173 ) -> std::task::Poll<Result<(), std::io::Error>> {
174 self.project().inner.poll_shutdown(cx)
175 }
176}
177
#[cfg(test)]
mod tests {
    use chacha20poly1305::{aead::stream::EncryptorLE31, KeyInit, XChaCha20Poly1305};
    use tokio::io::AsyncWriteExt;

    use crate::get_key;

    use super::*;

    /// Builds a fresh XChaCha20-Poly1305 stream encryptor from raw key and nonce bytes.
    fn encryptor(key: &[u8], nonce: &[u8]) -> EncryptorLE31<XChaCha20Poly1305> {
        EncryptorLE31::from_aead(XChaCha20Poly1305::new(key.into()), nonce.into())
    }

    /// Encrypts `plaintext` and frames it the way `WriteHalf` does: the ciphertext
    /// length as a little-endian `u32`, followed by the ciphertext.
    fn frame(encryptor: &mut EncryptorLE31<XChaCha20Poly1305>, plaintext: &[u8]) -> Vec<u8> {
        let ciphertext = encryptor.encrypt_next(plaintext).unwrap();
        let mut framed = Vec::with_capacity(std::mem::size_of::<u32>() + ciphertext.len());
        framed.extend_from_slice(&(ciphertext.len() as u32).to_le_bytes());
        framed.extend_from_slice(&ciphertext);
        framed
    }

    #[tokio::test]
    pub async fn test_crypto_stream_write_half() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];

        let expected = frame(&mut encryptor(&key, &start_nonce), b"some content");

        let mut writer = WriteHalf::new(
            tokio::io::BufWriter::new(Vec::new()),
            encryptor(&key, &start_nonce),
        );

        assert_eq!(
            writer.write(b"some content").await.unwrap(),
            b"some content".len()
        );

        assert_eq!(expected, writer.inner.buffer())
    }

    #[tokio::test]
    async fn test_crypto_stream_write_half_splits_into_chunks() {
        let key: [u8; 32] = get_key("key", "group");
        let start_nonce = [0u8; 20];
        let plaintext = b"some content";

        // Mirror the writer: one frame per 4-byte plaintext chunk, in order.
        let mut reference = encryptor(&key, &start_nonce);
        let expected: Vec<u8> = plaintext
            .chunks(4)
            .flat_map(|chunk| frame(&mut reference, chunk))
            .collect();

        let mut writer = WriteHalf::with_capacity(
            tokio::io::BufWriter::new(Vec::new()),
            encryptor(&key, &start_nonce),
            1024,
            4,
        );

        assert_eq!(writer.write(plaintext).await.unwrap(), plaintext.len());
        assert_eq!(expected, writer.inner.buffer());
    }
}