fractal_storage_client/stream/chacha20.rs

1use bytes::{Buf, BufMut, Bytes, BytesMut};
2use chacha20::cipher::{NewCipher, StreamCipher};
3use chacha20::{Key, XChaCha20, XNonce};
4use futures::task::Context;
5use futures::task::Poll;
6use futures::Stream;
7use log::debug;
8use rand_core::{OsRng, RngCore};
9use std::error::Error as StdError;
10use std::pin::Pin;
11
/// Lifecycle of an [`EncryptionStream`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EncryptionStreamState {
    /// Nothing emitted yet; the next poll sends the nonce header.
    Start,
    /// Header sent; forwarding encrypted chunks of the inner stream.
    Stream,
    /// The inner stream finished.
    Done,
    /// The inner stream produced an error, which has already been forwarded.
    Error,
}
19
20pub struct EncryptionStream<E: StdError> {
21    stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send + Sync>>,
22    state: EncryptionStreamState,
23    nonce: XNonce,
24    crypt: XChaCha20,
25}
26
27impl<E: StdError> EncryptionStream<E> {
28    pub fn new<S: Stream<Item = Result<Bytes, E>> + Send + Sync + 'static>(
29        stream: S,
30        key: &Key,
31    ) -> Self {
32        // generate nonce
33        let mut nonce = [0u8; 24];
34        OsRng.fill_bytes(&mut nonce);
35        let nonce = XNonce::from_slice(&nonce);
36
37        EncryptionStream {
38            state: EncryptionStreamState::Start,
39            stream: Box::pin(stream),
40            nonce: nonce.clone(),
41            crypt: XChaCha20::new(key, nonce),
42        }
43    }
44}
45
46impl<E: StdError> Stream for EncryptionStream<E> {
47    type Item = Result<Bytes, E>;
48
49    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
50        use EncryptionStreamState::*;
51        match self.state.clone() {
52            Start => {
53                let mut nonce: BytesMut = self.nonce.as_slice().into();
54                self.state = Stream;
55                match Pin::new(&mut self.stream).poll_next(cx) {
56                    Poll::Ready(Some(Ok(bytes))) => {
57                        let mut bytes = bytes.as_ref().to_vec();
58                        self.crypt.apply_keystream(&mut bytes);
59                        nonce.extend_from_slice(&bytes[..]);
60                    }
61                    Poll::Ready(Some(Err(error))) => {
62                        self.state = Error;
63                        return Poll::Ready(Some(Err(error)));
64                    }
65                    Poll::Ready(None) => {
66                        self.state = Done;
67                    }
68                    Poll::Pending => {}
69                }
70                debug!("Sending nonce");
71                Poll::Ready(Some(Ok(nonce.freeze())))
72            }
73            Stream => match Pin::new(&mut self.stream).poll_next(cx) {
74                Poll::Ready(Some(Err(error))) => {
75                    debug!("Read error from stream: {error:?}");
76                    self.state = Error;
77                    Poll::Ready(Some(Err(error)))
78                }
79                done @ Poll::Ready(None) => {
80                    debug!("Stream closed");
81                    self.state = Done;
82                    done
83                }
84                Poll::Ready(Some(Ok(bytes))) => {
85                    debug!("Read {} bytes", bytes.len());
86                    let mut bytes = bytes.as_ref().to_vec();
87                    self.crypt.apply_keystream(&mut bytes);
88                    Poll::Ready(Some(Ok(Bytes::copy_from_slice(&bytes))))
89                }
90                other => other,
91            },
92            Done | Error => Poll::Ready(None),
93        }
94    }
95}
96
97enum DecryptionStreamState {
98    Start(Key, BytesMut),
99    Stream(XChaCha20),
100    Done,
101    Error,
102}
103
104pub struct DecryptionStream<E: StdError> {
105    stream: Pin<Box<dyn Stream<Item = Result<Bytes, E>> + Send>>,
106    state: DecryptionStreamState,
107}
108
109impl<E: StdError> DecryptionStream<E> {
110    pub fn new<S: Stream<Item = Result<Bytes, E>> + Send + 'static>(stream: S, key: &Key) -> Self {
111        DecryptionStream {
112            stream: Box::pin(stream),
113            state: DecryptionStreamState::Start(key.clone(), BytesMut::with_capacity(24)),
114        }
115    }
116}
117
118impl<E: StdError> Stream for DecryptionStream<E> {
119    type Item = Result<Bytes, E>;
120
121    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
122        use DecryptionStreamState::*;
123        let result = Pin::new(&mut self.stream).poll_next(cx);
124        match &mut self.state {
125            Start(key, nonce) => match result {
126                Poll::Ready(Some(Ok(mut bytes))) => {
127                    debug!("Read {} bytes raw data", bytes.len());
128                    let nonce_data = bytes.split_to((24 - nonce.len()).min(bytes.len()));
129                    debug!("Putting nonce data");
130                    nonce.put(nonce_data);
131                    if nonce.len() == 24 {
132                        debug!("Got nonce");
133                        let nonce = XNonce::from_slice(&nonce);
134                        let mut crypter = XChaCha20::new(&key, &nonce);
135                        let mut bytes: BytesMut = bytes.chunk().into();
136                        debug!("Decrypting {} bytes", bytes.len());
137                        crypter.apply_keystream(&mut bytes);
138                        self.state = DecryptionStreamState::Stream(crypter);
139                        Poll::Ready(Some(Ok(bytes.freeze())))
140                    } else {
141                        Poll::Ready(Some(Ok(Bytes::new())))
142                    }
143                }
144                error @ Poll::Ready(Some(Err(_))) => {
145                    self.state = DecryptionStreamState::Error;
146                    error
147                }
148                done @ Poll::Ready(None) => {
149                    self.state = DecryptionStreamState::Done;
150                    done
151                }
152                result => result,
153            },
154            Stream(xchacha) => match result {
155                Poll::Ready(Some(Ok(bytes))) => {
156                    debug!("Read {} bytes raw data", bytes.len());
157                    let mut bytes: BytesMut = bytes.chunk().into();
158                    xchacha.apply_keystream(&mut bytes);
159                    Poll::Ready(Some(Ok(bytes.freeze())))
160                }
161                error @ Poll::Ready(Some(Err(_))) => {
162                    self.state = DecryptionStreamState::Error;
163                    error
164                }
165                done @ Poll::Ready(None) => {
166                    self.state = DecryptionStreamState::Done;
167                    done
168                }
169                result => result,
170            },
171            Done | Error => Poll::Ready(None),
172        }
173    }
174}
175
/// An empty input still produces exactly one item (the bare nonce) and then stays finished.
#[cfg(test)]
#[tokio::test]
async fn encrypt_empty_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let source = futures::stream::iter(vec![]);
    let mut encrypted = EncryptionStream::<std::io::Error>::new(source, key);
    assert_eq!(encrypted.state, EncryptionStreamState::Start);

    let header = encrypted.next().await.unwrap().unwrap();
    assert_eq!(header, encrypted.nonce.as_slice());
    assert_eq!(encrypted.state, EncryptionStreamState::Done);

    // Polling a finished stream keeps returning None.
    for _ in 0..2 {
        assert!(encrypted.next().await.is_none());
        assert_eq!(encrypted.state, EncryptionStreamState::Done);
    }
}
195
/// A single input chunk is merged into the header item: the nonce in clear followed
/// by the five encrypted bytes of `hello`.
#[cfg(test)]
#[tokio::test]
async fn encrypt_single_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let stream = futures::stream::iter(vec![Ok(Bytes::copy_from_slice(b"hello"))]);
    let mut crypt_stream = EncryptionStream::<std::io::Error>::new(stream, key);
    assert_eq!(crypt_stream.state, EncryptionStreamState::Start);

    let first = crypt_stream.next().await.unwrap().unwrap();
    let nonce_len = crypt_stream.nonce.as_slice().len();
    assert_eq!(first.len(), nonce_len + 5);
    assert_eq!(&first[..nonce_len], crypt_stream.nonce.as_slice());
    assert_eq!(crypt_stream.state, EncryptionStreamState::Stream);

    assert!(crypt_stream.next().await.is_none());
    assert_eq!(crypt_stream.state, EncryptionStreamState::Done);
}
216
/// Three input chunks map to three output items; only the first carries the nonce.
#[cfg(test)]
#[tokio::test]
async fn encrypt_multi_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let chunks = vec![
        Ok(Bytes::copy_from_slice(b"hello")),
        Ok(Bytes::copy_from_slice(b"there")),
        Ok(Bytes::copy_from_slice(b"world!")),
    ];
    let mut encrypted =
        EncryptionStream::<std::io::Error>::new(futures::stream::iter(chunks), key);
    assert_eq!(encrypted.state, EncryptionStreamState::Start);

    let header_len = encrypted.nonce.as_slice().len();
    for expected_len in [header_len + 5, 5, 6] {
        let item = encrypted.next().await.unwrap();
        assert_eq!(item.unwrap().len(), expected_len);
        assert_eq!(encrypted.state, EncryptionStreamState::Stream);
    }

    assert!(encrypted.next().await.is_none());
    assert_eq!(encrypted.state, EncryptionStreamState::Done);
}
248
/// An inner error is forwarded once; the chunk after it is never emitted.
#[cfg(test)]
#[tokio::test]
async fn encrypt_error_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let failure = std::io::Error::new(std::io::ErrorKind::InvalidData, "error");
    let source = futures::stream::iter(vec![
        Ok(Bytes::copy_from_slice(b"hello")),
        Err(failure),
        Ok(Bytes::copy_from_slice(b"world!")),
    ]);
    let mut encrypted = EncryptionStream::<std::io::Error>::new(source, key);
    assert_eq!(encrypted.state, EncryptionStreamState::Start);

    let header = encrypted.next().await.unwrap().unwrap();
    assert_eq!(header.len(), encrypted.nonce.as_slice().len() + 5);
    assert_eq!(encrypted.state, EncryptionStreamState::Stream);

    let forwarded = encrypted.next().await.unwrap();
    assert!(forwarded.is_err());
    assert_eq!(encrypted.state, EncryptionStreamState::Error);

    assert!(encrypted.next().await.is_none());
    assert_eq!(encrypted.state, EncryptionStreamState::Error);
}
279
/// A complete 24-byte nonce with no payload decrypts to one empty item.
#[cfg(test)]
#[tokio::test]
async fn decrypt_empty_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    // Exactly 24 bytes: a full nonce header.
    let header = Bytes::from("s91hd9v0-dk2ldlv;as920di");
    let mut decrypted =
        DecryptionStream::<std::io::Error>::new(futures::stream::iter(vec![Ok(header)]), key);

    assert_eq!(decrypted.next().await.unwrap().unwrap(), Bytes::new());
    assert!(decrypted.next().await.is_none());
}
293
/// A nonce header split across two chunks (11 + 13 bytes) yields one empty item per chunk.
#[cfg(test)]
#[tokio::test]
async fn decrypt_empty_stream2() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let fragments = vec![
        Ok(Bytes::from("s91hd9v0-dk")),
        Ok(Bytes::from("2ldlv;as920di")),
    ];
    let mut decrypted =
        DecryptionStream::<std::io::Error>::new(futures::stream::iter(fragments), key);

    // First item: header incomplete. Second: header complete, no payload.
    for _ in 0..2 {
        assert_eq!(decrypted.next().await.unwrap().unwrap(), Bytes::new());
    }
    assert!(decrypted.next().await.is_none());
}
312
/// Round trip of an empty input: the nonce-only header decrypts to one empty item.
#[cfg(test)]
#[tokio::test]
async fn endtoend_empty_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let plaintext = futures::stream::iter(vec![]);
    let ciphertext = EncryptionStream::<std::io::Error>::new(plaintext, key);
    let mut roundtrip = DecryptionStream::<std::io::Error>::new(ciphertext, key);

    assert_eq!(roundtrip.next().await.unwrap().unwrap(), Bytes::new());
    assert!(roundtrip.next().await.is_none());
}
327
/// Round trip of a single chunk recovers it unchanged.
#[cfg(test)]
#[tokio::test]
async fn endtoend_single_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let message = Bytes::from("hello, world!");
    let ciphertext = EncryptionStream::<std::io::Error>::new(
        futures::stream::iter(vec![Ok(message.clone())]),
        key,
    );
    let mut roundtrip = DecryptionStream::<std::io::Error>::new(ciphertext, key);

    assert_eq!(roundtrip.next().await.unwrap().unwrap(), message);
    assert!(roundtrip.next().await.is_none());
}
343
/// Round trip of two chunks recovers each chunk in order.
#[cfg(test)]
#[tokio::test]
async fn endtoend_multi_stream() {
    use futures::StreamExt;

    let key = Key::from_slice(b"abcdefghijklmnopqrstuvwxyz012345");
    let expected = [Bytes::from("hello, world!"), Bytes::from("this is an example")];
    let source = futures::stream::iter(vec![Ok(expected[0].clone()), Ok(expected[1].clone())]);
    let ciphertext = EncryptionStream::<std::io::Error>::new(source, key);
    let mut roundtrip = DecryptionStream::<std::io::Error>::new(ciphertext, key);

    for want in &expected {
        assert_eq!(&roundtrip.next().await.unwrap().unwrap(), want);
    }
    assert!(roundtrip.next().await.is_none());
}