ratrodlib/
buffed_stream.rs

//! Buffed stream module.
//!
//! This module contains the `BuffedStream` type, a wrapper around a stream that provides
//! buffering and encryption/decryption functionality.
//!
//! It provides a bincode-centric stream for sending and receiving protocol messages
//! efficiently. In addition, the `AsyncRead` and `AsyncWrite` implementations are designed to
//! "transparently" encrypt and decrypt the data being sent and received (for the "pump" phase
//! of the lifecycle).

11use std::{
12    pin::Pin,
13    task::{Context, Poll, ready},
14};
15
16use async_bincode::{
17    AsyncDestination,
18    tokio::{AsyncBincodeReader, AsyncBincodeWriter},
19};
20use futures::{Sink, Stream};
21use secrecy::ExposeSecret;
22use tokio::{
23    io::{AsyncRead, AsyncWrite, DuplexStream, ReadHalf, SimplexStream, WriteHalf},
24    net::{
25        TcpStream,
26        tcp::{OwnedReadHalf, OwnedWriteHalf},
27    },
28};
29
30use crate::{
31    base::{Constant, SharedSecret},
32    protocol::{ProtocolMessage, ProtocolMessageWrapper},
33    utils::{decrypt, encrypt},
34};
35
36// Macros.
37
/// Re-pins the `inner` field of `$this` through a mutable borrow.
///
/// `$this` is typically a `mut self: Pin<&mut Self>`; it stays usable after the macro, because
/// only a borrow of it is taken.
macro_rules! pinned_inner {
    ($this:ident) => {
        ::std::pin::Pin::new(&mut $this.inner)
    };
}
44
/// Consumes a `Pin<&mut Self>` (via `Pin::get_mut`, so `Self` must be `Unpin`) and re-pins its
/// `inner` field.
macro_rules! take_pinned_inner {
    ($this:ident) => {
        ::std::pin::Pin::new(&mut $this.get_mut().inner)
    };
}
51
/// Consumes a `Pin<&mut Self>` (via `Pin::get_mut`, so `Self` must be `Unpin`) and re-pins its
/// `inner_read` field (the read half).
macro_rules! take_pinned_inner_read {
    ($this:ident) => {
        ::std::pin::Pin::new(&mut $this.get_mut().inner_read)
    };
}
58
/// Consumes a `Pin<&mut Self>` (via `Pin::get_mut`, so `Self` must be `Unpin`) and re-pins its
/// `inner_write` field (the write half).
macro_rules! take_pinned_inner_write {
    ($this:ident) => {
        ::std::pin::Pin::new(&mut $this.get_mut().inner_write)
    };
}
65
/// Re-pins the interim decrypted-data buffer (`read_stream`) through a mutable borrow of `$this`.
///
/// # Panics
///
/// Panics if `read_stream` is `None`. It is only initialised when encryption is enabled, so
/// callers must check for a shared secret first.
macro_rules! pinned_read_stream {
    ($self:ident) => {
        Pin::new(
            $self
                .read_stream
                .as_mut()
                .expect("`read_stream` is only initialised when encryption is enabled"),
        )
    };
}
72
/// Consumes a `Pin<&mut Self>` (via `Pin::get_mut`, so `Self` must be `Unpin`) and re-pins its
/// interim decrypted-data buffer (`read_stream`).
///
/// # Panics
///
/// Panics if `read_stream` is `None`. It is only initialised when encryption is enabled, so
/// callers must check for a shared secret first.
macro_rules! take_pinned_read_stream {
    ($self:ident) => {
        Pin::new(
            $self
                .get_mut()
                .read_stream
                .as_mut()
                .expect("`read_stream` is only initialised when encryption is enabled"),
        )
    };
}
79
80// Types.
81
82pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
83pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
84
85/// BuffedStream type.
86///
87/// This type is a wrapper around a stream that provides buffering and encryption/decryption functionality.
88/// It is used to provide a bincode-centric stream that can be used to send and receive data
89/// in a more efficient manner.
90///
91/// > This type is used to provide a bincode-centric stream that can be used to send and receive data
92/// > so it is inadvisable to use any other methods than the `push` and `pull` methods from the protocol
93/// > module.  Using `read` and `write` directly will bypass the normal logic, and should only be used when
94/// > you know what you are doing (most common use case is pumping data).
95///
96/// The `shared_secret` field is used to encrypt and decrypt data.
97/// The `read_stream` field is used to buffer data that has been decrypted.
98pub struct BuffedStream<R, W> {
99    inner_read: BuffedStreamReadHalf<R>,
100    inner_write: BuffedStreamWriteHalf<W>,
101}
102
103// Impl.
104
105impl<R, W> BuffedStream<R, W> {
106    /// Sets the shared secret for the stream, and enables encryption / decryption.
107    pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
108        let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
109
110        self.inner_read.shared_secret = Some(secret_clone);
111        self.inner_read.read_stream = Some(SimplexStream::new_unsplit(Constant::BUFFER_SIZE));
112        self.inner_write.shared_secret = Some(shared_secret);
113
114        self
115    }
116
117    pub fn into_split(self) -> (BuffedStreamReadHalf<R>, BuffedStreamWriteHalf<W>) {
118        (self.inner_read, self.inner_write)
119    }
120}
121
122impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
123    fn from(stream: TcpStream) -> Self {
124        let (read, write) = stream.into_split();
125
126        Self {
127            inner_read: BuffedStreamReadHalf::new(read),
128            inner_write: BuffedStreamWriteHalf::new(write),
129        }
130    }
131}
132
133impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
134where
135    T: AsyncRead + AsyncWrite + Unpin,
136{
137    fn from(stream: T) -> Self {
138        let (read, write) = tokio::io::split(stream);
139
140        Self {
141            inner_read: BuffedStreamReadHalf::new(read),
142            inner_write: BuffedStreamWriteHalf::new(write),
143        }
144    }
145}
146
147impl<R, W> BuffedStream<R, W>
148where
149    R: AsyncRead + Unpin,
150    W: AsyncWrite + Unpin,
151{
152    pub fn new(inner_read: R, inner_write: W) -> Self {
153        Self {
154            inner_read: BuffedStreamReadHalf::new(inner_read),
155            inner_write: BuffedStreamWriteHalf::new(inner_write),
156        }
157    }
158}
159
160impl<R> BuffedStream<R, OwnedWriteHalf> {
161    pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
162        self.inner_write.inner.get_ref()
163    }
164}
165
166impl<W> BuffedStream<OwnedReadHalf, W> {
167    pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
168        self.inner_read.inner.get_ref()
169    }
170}
171
172// Trait impls.
173
174impl<R, W> Stream for BuffedStream<R, W>
175where
176    R: AsyncRead + Unpin,
177    W: Unpin,
178{
179    type Item = std::io::Result<ProtocolMessage>;
180
181    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
182        take_pinned_inner_read!(self).poll_next(cx)
183    }
184}
185
186impl<R, W> Sink<ProtocolMessage> for BuffedStream<R, W>
187where
188    R: Unpin,
189    W: AsyncWrite + Unpin,
190{
191    type Error = std::io::Error;
192
193    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
194        take_pinned_inner_write!(self).poll_ready(cx)
195    }
196
197    fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
198        take_pinned_inner_write!(self).start_send(item)
199    }
200
201    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
202        futures::Sink::<ProtocolMessage>::poll_flush(take_pinned_inner_write!(self), cx)
203    }
204
205    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
206        take_pinned_inner_write!(self).poll_close(cx)
207    }
208}
209
210impl<R, W> AsyncRead for BuffedStream<R, W>
211where
212    R: AsyncRead + Unpin,
213    W: Unpin,
214{
215    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
216        take_pinned_inner_read!(self).poll_read(cx, buf)
217    }
218}
219
220impl<R, W> AsyncWrite for BuffedStream<R, W>
221where
222    R: Unpin,
223    W: AsyncWrite + Unpin,
224{
225    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
226        take_pinned_inner_write!(self).poll_write(cx, buf)
227    }
228
229    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
230        AsyncWrite::poll_flush(take_pinned_inner_write!(self), cx)
231    }
232
233    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
234        take_pinned_inner_write!(self).poll_shutdown(cx)
235    }
236}
237
238// Split streams.
239
240pub struct BuffedStreamReadHalf<T> {
241    inner: AsyncBincodeReader<T, ProtocolMessageWrapper>,
242    shared_secret: Option<SharedSecret>,
243    read_stream: Option<SimplexStream>,
244}
245
246impl<T> BuffedStreamReadHalf<T>
247where
248    T: AsyncRead + Unpin,
249{
250    fn new(stream: T) -> Self {
251        Self {
252            inner: AsyncBincodeReader::from(stream),
253            shared_secret: None,
254            read_stream: None,
255        }
256    }
257}
258
259impl<T> Stream for BuffedStreamReadHalf<T>
260where
261    T: AsyncRead + Unpin,
262{
263    type Item = std::io::Result<ProtocolMessage>;
264
265    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
266        // Get an option to the shared secret.
267        let key = self.shared_secret.as_ref().map(|s| SharedSecret::init_with(|| *s.expose_secret()));
268
269        match take_pinned_inner!(self).poll_next(cx) {
270            Poll::Ready(Some(Ok(wrapper))) => match wrapper {
271                ProtocolMessageWrapper::Plain(message) => Poll::Ready(Some(Ok(message))),
272                ProtocolMessageWrapper::Encrypted { nonce, data } => {
273                    let Some(key) = key else {
274                        return Poll::Ready(Some(Err(std::io::Error::new(
275                            std::io::ErrorKind::InvalidData,
276                            "Received encrypted message without shared secret on this end",
277                        ))));
278                    };
279
280                    let Ok(decrypted_data) = decrypt(&key, &data, &nonce) else {
281                        return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed"))));
282                    };
283
284                    let Ok(message) = bincode::deserialize::<ProtocolMessage>(&decrypted_data) else {
285                        return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize decrypted data"))));
286                    };
287
288                    Poll::Ready(Some(Ok(message)))
289                }
290            },
291            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(std::io::Error::new(
292                std::io::ErrorKind::InvalidData,
293                format!("Error on bincode reading during stream next: {}", e),
294            )))),
295            Poll::Ready(None) => Poll::Ready(None),
296            Poll::Pending => Poll::Pending,
297        }
298    }
299}
300
301impl<T> AsyncRead for BuffedStreamReadHalf<T>
302where
303    T: AsyncRead + Unpin,
304{
305    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
306        // Read directly from the inner stream if there is no shared secret.
307        // This is the case where we are not encrypting the stream.
308        //
309        // Basically, this is an optimization that both `poll_read` and `poll_write` can use for
310        // the case where we are not encrypting the stream.
311        if self.shared_secret.is_none() {
312            return Pin::new(self.inner.get_mut()).poll_read(cx, buf);
313        }
314
315        // Use the "self" reader to get the next packet (and perform any needed decryption).
316        // TODO: We could actually loop on poll next here until we either get a pending, or we no longer have space in the
317        // read stream.
318        let result = self.as_mut().poll_next(cx);
319
320        match result {
321            Poll::Ready(Some(Ok(message))) => {
322                let ProtocolMessage::Data(data) = message else {
323                    return Poll::Ready(Err(std::io::Error::new(
324                        std::io::ErrorKind::InvalidData,
325                        "Received non-data message during `poll_read`, which shouldn't happen",
326                    )));
327                };
328
329                // We have the data, so we can write it to the `read_stream`.
330                let written = ready!(pinned_read_stream!(self).poll_write(cx, &data)?);
331
332                // Fail if the interim buffer is too small.
333                if written < data.len() {
334                    return Poll::Ready(Err(std::io::Error::new(
335                        std::io::ErrorKind::InvalidData,
336                        "Decryption stream buffer overflow (shouldn't happen unless there is a mismatched buffer size between client and server)",
337                    )));
338                }
339
340                // Flush the `read_stream` to ensure that the data is available for reading (see below).
341                ready!(pinned_read_stream!(self).poll_flush(cx)?);
342            }
343            Poll::Ready(Some(Err(e))) => {
344                // This is the case where we have a bincode error, so we should return the error.
345
346                return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Error on bincode reading during pump: {}", e))));
347            }
348            Poll::Ready(None) => {
349                // If we read no data from the inner buffer, then we are "shutdown",
350                // so we should shutdown the write side of the `decryption_stream`, and
351                // return the final poll result (bottom of function).
352
353                ready!(pinned_read_stream!(self).poll_shutdown(cx)?);
354            }
355            Poll::Pending => {
356                // If we are pending, then we should pass through to the underlying decryption stream (so do nothing here).
357                // The underlying decryption stream will be properly shutdown in the case of a shutdown on the inner stream.
358            }
359        }
360
361        // At this point, if there was data to decrypt, we have decrypted it; if not, we may have some data in the
362        // decrypted stream, so we just offload onto its `poll_read` method.
363        take_pinned_read_stream!(self).poll_read(cx, buf)
364    }
365}
366
367pub struct BuffedStreamWriteHalf<T> {
368    inner: AsyncBincodeWriter<T, ProtocolMessageWrapper, AsyncDestination>,
369    shared_secret: Option<SharedSecret>,
370}
371
372impl<T> BuffedStreamWriteHalf<T>
373where
374    T: AsyncWrite + Unpin,
375{
376    fn new(stream: T) -> Self {
377        Self {
378            inner: AsyncBincodeWriter::from(stream).for_async(),
379            shared_secret: None,
380        }
381    }
382}
383
384impl<T> Sink<ProtocolMessage> for BuffedStreamWriteHalf<T>
385where
386    T: AsyncWrite + Unpin,
387{
388    type Error = std::io::Error;
389
390    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
391        take_pinned_inner!(self)
392            .poll_ready(cx)
393            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
394    }
395
396    fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
397        if let Some(key) = self.shared_secret.as_ref() {
398            let encrypted_data = encrypt(
399                key,
400                &bincode::serialize(&item).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to serialize message"))?,
401            )
402            .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Encryption failed"))?;
403
404            let message = ProtocolMessageWrapper::Encrypted {
405                nonce: encrypted_data.nonce,
406                data: encrypted_data.data,
407            };
408
409            take_pinned_inner!(self)
410                .start_send(message)
411                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write encrypted packet: {}", e)))?;
412
413            return Ok(());
414        }
415
416        take_pinned_inner!(self)
417            .start_send(ProtocolMessageWrapper::Plain(item))
418            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write plain packet: {}", e)))
419    }
420
421    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
422        futures::Sink::<ProtocolMessageWrapper>::poll_flush(take_pinned_inner!(self), cx)
423            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
424    }
425
426    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
427        take_pinned_inner!(self)
428            .poll_close(cx)
429            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
430    }
431}
432
433impl<T> AsyncWrite for BuffedStreamWriteHalf<T>
434where
435    T: AsyncWrite + Unpin,
436{
437    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
438        // If there is no shared secret, then we are not encrypting the stream,
439        // so we can just write directly to the inner stream.
440        //
441        // Basically, this is an optimization that both `poll_read` and `poll_write` can use for
442        // the case where we are not encrypting the stream.
443        if self.shared_secret.is_none() {
444            return Pin::new(self.inner.get_mut()).poll_write(cx, buf);
445        }
446
447        // First, we need to pare down the data to the maximum size of the encrypted packet, if needed.
448        let max_size = Constant::BUFFER_SIZE - Constant::ENCRYPTION_OVERHEAD;
449        let amt = std::cmp::min(buf.len(), max_size);
450        let buf = &buf[..amt];
451
452        let message = ProtocolMessage::Data(buf.to_vec());
453
454        // Write the encrypted data to the "self" `start_send`, which performs any needed encryption logic.
455        self.as_mut()
456            .start_send(message)
457            .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to write encrypted packet"))?;
458
459        // Need to report the amount of data that was written _from the input_, not the _actual_ amount written to the inner stream.
460        // This allows the caller to know how much of _their_ data was written, which is all that matters.
461        Poll::Ready(Ok(buf.len()))
462    }
463
464    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
465        pinned_inner!(self)
466            .poll_flush(cx)
467            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
468    }
469
470    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
471        pinned_inner!(self)
472            .poll_close(cx)
473            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
474    }
475}
476
477// Tests.
478
#[cfg(test)]
mod tests {
    use futures::future::join_all;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption};

    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        // Distinct payloads, so the assertion also catches reordered or dropped packets.
        let data1 = b"Hello, world!";
        let data2 = b"Goodbye, world!";

        client.write_all(data1).await.unwrap();
        client.write_all(data2).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!([&data1[..], &data2[..]].concat(), received);
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_large_data() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";
        let data = data.repeat(10000);

        let data_clone = data.clone();

        let write_task = tokio::spawn(async move {
            client.write_all(&data_clone).await.unwrap();
            client.shutdown().await.unwrap();
        });

        let read_task = tokio::spawn(async move {
            let mut received = Vec::new();
            server.read_to_end(&mut received).await.unwrap();
            // Compare contents, not just lengths, so corrupted or reordered chunks are caught.
            assert_eq!(data, received);
        });

        join_all([write_task, read_task]).await.into_iter().collect::<Result<Vec<_>, _>>().unwrap();
    }
}