fractal_storage_client/stream/
chacha20.rs1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use chacha20::cipher::{NewCipher, StreamCipher};
3use chacha20::{Key, XChaCha20, XNonce};
4use futures::task::Context;
5use futures::task::Poll;
6use futures::Stream;
7use log::debug;
8use rand_core::{OsRng, RngCore};
9use std::error::Error as StdError;
10use std::pin::Pin;
11
/// Where an [`EncryptionStream`] is in its output sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EncryptionStreamState {
    /// Nothing emitted yet; the next item carries the nonce.
    Start,
    /// The nonce has been sent; inner chunks are encrypted and forwarded.
    Stream,
    /// The inner stream has ended.
    Done,
    /// The inner stream produced an error, which has already been forwarded.
    Error,
}
19
20pub struct EncryptionStream<E: StdError> {
21 stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send + Sync>>,
22 state: EncryptionStreamState,
23 nonce: XNonce,
24 crypt: XChaCha20,
25}
26
27impl<E: StdError> EncryptionStream<E> {
28 pub fn new<S: Stream<Item = Result<Bytes, E>> + Send + Sync + 'static>(
29 stream: S,
30 key: &Key,
31 ) -> Self {
32 let mut nonce = [0u8; 24];
34 OsRng.fill_bytes(&mut nonce);
35 let nonce = XNonce::from_slice(&nonce);
36
37 EncryptionStream {
38 state: EncryptionStreamState::Start,
39 stream: Box::pin(stream),
40 nonce: nonce.clone(),
41 crypt: XChaCha20::new(key, nonce),
42 }
43 }
44}
45
46impl<E: StdError> Stream for EncryptionStream<E> {
47 type Item = Result<Bytes, E>;
48
49 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
50 use EncryptionStreamState::*;
51 match self.state.clone() {
52 Start => {
53 let mut nonce: BytesMut = self.nonce.as_slice().into();
54 self.state = Stream;
55 match Pin::new(&mut self.stream).poll_next(cx) {
56 Poll::Ready(Some(Ok(bytes))) => {
57 let mut bytes = bytes.as_ref().to_vec();
58 self.crypt.apply_keystream(&mut bytes);
59 nonce.extend_from_slice(&bytes[..]);
60 }
61 Poll::Ready(Some(Err(error))) => {
62 self.state = Error;
63 return Poll::Ready(Some(Err(error)));
64 }
65 Poll::Ready(None) => {
66 self.state = Done;
67 }
68 Poll::Pending => {}
69 }
70 debug!("Sending nonce");
71 Poll::Ready(Some(Ok(nonce.freeze())))
72 }
73 Stream => match Pin::new(&mut self.stream).poll_next(cx) {
74 Poll::Ready(Some(Err(error))) => {
75 debug!("Read error from stream: {error:?}");
76 self.state = Error;
77 Poll::Ready(Some(Err(error)))
78 }
79 done @ Poll::Ready(None) => {
80 debug!("Stream closed");
81 self.state = Done;
82 done
83 }
84 Poll::Ready(Some(Ok(bytes))) => {
85 debug!("Read {} bytes", bytes.len());
86 let mut bytes = bytes.as_ref().to_vec();
87 self.crypt.apply_keystream(&mut bytes);
88 Poll::Ready(Some(Ok(Bytes::copy_from_slice(&bytes))))
89 }
90 other => other,
91 },
92 Done | Error => Poll::Ready(None),
93 }
94 }
95}
96
97enum DecryptionStreamState {
98 Start(Key, BytesMut),
99 Stream(XChaCha20),
100 Done,
101 Error,
102}
103
104pub struct DecryptionStream<E: StdError> {
105 stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send>>,
106 state: DecryptionStreamState,
107}
108
109impl<E: StdError> DecryptionStream<E> {
110 pub fn new<S: Stream<Item = Result<Bytes, E>> + Send + 'static>(stream: S, key: &Key) -> Self {
111 DecryptionStream {
112 stream: Box::pin(stream),
113 state: DecryptionStreamState::Start(key.clone(), BytesMut::with_capacity(24)),
114 }
115 }
116}
117
118impl<E: StdError> Stream for DecryptionStream<E> {
119 type Item = Result<Bytes, E>;
120
121 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122 use DecryptionStreamState::*;
123 let result = Pin::new(&mut self.stream).poll_next(cx);
124 match &mut self.state {
125 Start(key, nonce) => match result {
126 Poll::Ready(Some(Ok(mut bytes))) => {
127 debug!("Read {} bytes raw data", bytes.len());
128 let nonce_data = bytes.split_to((24 - nonce.len()).min(bytes.len()));
129 debug!("Putting nonce data");
130 nonce.put(nonce_data);
131 if nonce.len() == 24 {
132 debug!("Got nonce");
133 let nonce = XNonce::from_slice(&nonce);
134 let mut crypter = XChaCha20::new(&key, &nonce);
135 let mut bytes: BytesMut = bytes.chunk().into();
136 debug!("Decrypting {} bytes", bytes.len());
137 crypter.apply_keystream(&mut bytes);
138 self.state = DecryptionStreamState::Stream(crypter);
139 Poll::Ready(Some(Ok(bytes.freeze())))
140 } else {
141 Poll::Ready(Some(Ok(Bytes::new())))
142 }
143 }
144 error @ Poll::Ready(Some(Err(_))) => {
145 self.state = DecryptionStreamState::Error;
146 error
147 }
148 done @ Poll::Ready(None) => {
149 self.state = DecryptionStreamState::Done;
150 done
151 }
152 result => result,
153 },
154 Stream(xchacha) => match result {
155 Poll::Ready(Some(Ok(bytes))) => {
156 debug!("Read {} bytes raw data", bytes.len());
157 let mut bytes: BytesMut = bytes.chunk().into();
158 xchacha.apply_keystream(&mut bytes);
159 Poll::Ready(Some(Ok(bytes.freeze())))
160 }
161 error @ Poll::Ready(Some(Err(_))) => {
162 self.state = DecryptionStreamState::Error;
163 error
164 }
165 done @ Poll::Ready(None) => {
166 self.state = DecryptionStreamState::Done;
167 done
168 }
169 result => result,
170 },
171 Done | Error => Poll::Ready(None),
172 }
173 }
174}
175
#[cfg(test)]
#[tokio::test]
async fn encrypt_empty_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let source = futures::stream::iter(Vec::new());
    let mut encrypter = EncryptionStream::<std::io::Error>::new(source, key);
    assert_eq!(encrypter.state, EncryptionStreamState::Start);

    // With no payload, the only item is the bare nonce.
    let header = encrypter.next().await.unwrap().unwrap();
    assert_eq!(header, encrypter.nonce.as_slice());
    assert_eq!(encrypter.state, EncryptionStreamState::Done);

    // Once done, the stream keeps reporting exhaustion.
    for _ in 0..2 {
        assert!(encrypter.next().await.is_none());
        assert_eq!(encrypter.state, EncryptionStreamState::Done);
    }
}
195
#[cfg(test)]
#[tokio::test]
async fn encrypt_single_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let stream = futures::stream::iter(vec![Ok(Bytes::copy_from_slice(b"hello"))]);
    let mut crypt_stream = EncryptionStream::<std::io::Error>::new(stream, key);
    assert_eq!(crypt_stream.state, EncryptionStreamState::Start);

    // The only chunk is ready immediately, so it is sent together with the nonce.
    let first = crypt_stream.next().await.unwrap().unwrap();
    let nonce_len = crypt_stream.nonce.as_slice().len();
    assert_eq!(first.len(), nonce_len + 5);
    assert_eq!(&first[..nonce_len], crypt_stream.nonce.as_slice());
    // The payload must not be sent in the clear.
    assert_ne!(&first[nonce_len..], &b"hello"[..]);
    assert_eq!(crypt_stream.state, EncryptionStreamState::Stream);

    assert!(crypt_stream.next().await.is_none());
    assert_eq!(crypt_stream.state, EncryptionStreamState::Done);
}
216
#[cfg(test)]
#[tokio::test]
async fn encrypt_multi_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let chunks = vec![
        Ok(Bytes::copy_from_slice(b"hello")),
        Ok(Bytes::copy_from_slice(b"there")),
        Ok(Bytes::copy_from_slice(b"world!")),
    ];
    let mut encrypter =
        EncryptionStream::<std::io::Error>::new(futures::stream::iter(chunks), key);
    assert_eq!(encrypter.state, EncryptionStreamState::Start);

    // The first item carries the nonce as a prefix; later items are ciphertext only.
    let nonce_len = encrypter.nonce.as_slice().len();
    for expected_len in [nonce_len + 5, 5, 6] {
        let item = encrypter.next().await.unwrap().unwrap();
        assert_eq!(item.len(), expected_len);
        assert_eq!(encrypter.state, EncryptionStreamState::Stream);
    }

    assert!(encrypter.next().await.is_none());
    assert_eq!(encrypter.state, EncryptionStreamState::Done);
}
248
#[cfg(test)]
#[tokio::test]
async fn encrypt_error_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let failure = std::io::Error::new(std::io::ErrorKind::InvalidData, "error");
    let source = futures::stream::iter(vec![
        Ok(Bytes::copy_from_slice(b"hello")),
        Err(failure),
        Ok(Bytes::copy_from_slice(b"world!")),
    ]);
    let mut encrypter = EncryptionStream::<std::io::Error>::new(source, key);
    assert_eq!(encrypter.state, EncryptionStreamState::Start);

    // Nonce plus the encrypted "hello".
    let header = encrypter.next().await.unwrap().unwrap();
    assert_eq!(header.len(), encrypter.nonce.as_slice().len() + 5);
    assert_eq!(encrypter.state, EncryptionStreamState::Stream);

    // The error is forwarded, and the stream then stops before "world!".
    assert!(encrypter.next().await.unwrap().is_err());
    assert_eq!(encrypter.state, EncryptionStreamState::Error);

    assert!(encrypter.next().await.is_none());
    assert_eq!(encrypter.state, EncryptionStreamState::Error);
}
279
#[cfg(test)]
#[tokio::test]
async fn decrypt_empty_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    // Exactly one 24-byte nonce and no ciphertext.
    let input = futures::stream::iter(vec![Ok(Bytes::from("s91hd9v0-dk2ldlv;as920di"))]);
    let mut decrypter = DecryptionStream::<std::io::Error>::new(input, key);

    let plain = decrypter.next().await.unwrap().unwrap();
    assert_eq!(plain, Bytes::new());
    assert!(decrypter.next().await.is_none());
}
293
#[cfg(test)]
#[tokio::test]
async fn decrypt_empty_stream2() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    // The 24-byte nonce arrives split across two chunks (11 + 13 bytes).
    let fragments = vec![
        Ok(Bytes::from("s91hd9v0-dk")),
        Ok(Bytes::from("2ldlv;as920di")),
    ];
    let mut decrypter =
        DecryptionStream::<std::io::Error>::new(futures::stream::iter(fragments), key);

    // Each fragment yields an empty chunk: there is no ciphertext after the nonce.
    for _ in 0..2 {
        let plain = decrypter.next().await.unwrap().unwrap();
        assert_eq!(plain, Bytes::new());
    }

    assert!(decrypter.next().await.is_none());
}
312
#[cfg(test)]
#[tokio::test]
async fn endtoend_empty_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let encrypted = EncryptionStream::<std::io::Error>::new(futures::stream::iter(vec![]), key);
    let mut decrypted = DecryptionStream::<std::io::Error>::new(encrypted, key);

    // The nonce-only item decrypts to an empty chunk, then the stream ends.
    assert_eq!(decrypted.next().await.unwrap().unwrap(), Bytes::new());
    assert!(decrypted.next().await.is_none());
}
327
#[cfg(test)]
#[tokio::test]
async fn endtoend_single_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let original = Bytes::from("hello, world!");
    let source = futures::stream::iter(vec![Ok(original.clone())]);
    let encrypted = EncryptionStream::<std::io::Error>::new(source, key);
    let mut decrypted = DecryptionStream::<std::io::Error>::new(encrypted, key);

    // Round trip restores the plaintext exactly.
    let recovered = decrypted.next().await.unwrap().unwrap();
    assert_eq!(recovered, original);

    assert!(decrypted.next().await.is_none());
}
343
#[cfg(test)]
#[tokio::test]
async fn endtoend_multi_stream() {
    use futures::StreamExt;
    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let first = Bytes::from("hello, world!");
    let second = Bytes::from("this is an example");
    let source = futures::stream::iter(vec![Ok(first.clone()), Ok(second.clone())]);
    let encrypted = EncryptionStream::<std::io::Error>::new(source, key);
    let mut decrypted = DecryptionStream::<std::io::Error>::new(encrypted, key);

    // Chunk boundaries survive the round trip.
    for expected in [first, second] {
        assert_eq!(decrypted.next().await.unwrap().unwrap(), expected);
    }

    assert!(decrypted.next().await.is_none());
}