ratrodlib/
buffed_stream.rs

1//! Buffed stream module.
2//!
3//! This module contains the `BuffedStream` type, which is a wrapper around a stream that provides
4//! buffering and encryption/decryption functionality.
5//!
6//! It is used to provide a bincode-centric stream that can be used to send and receive data
7//! in a more efficient manner.  In addition, the `AsyncRead` and `AsyncWrite` implementations
8//! are designed to "transparently" handle encryption and decryption of the data being sent
9//! and received (for the "pump" phase of the lifecycle).
10
11use std::{
12    pin::Pin,
13    task::{Context, Poll, ready},
14};
15
16use async_bincode::{
17    AsyncDestination,
18    tokio::{AsyncBincodeReader, AsyncBincodeWriter},
19};
20use futures::{Sink, Stream};
21use secrecy::ExposeSecret;
22use tokio::{
23    io::{AsyncRead, AsyncWrite, DuplexStream, ReadHalf, SimplexStream, WriteHalf},
24    net::{
25        TcpStream,
26        tcp::{OwnedReadHalf, OwnedWriteHalf},
27    },
28};
29
30use crate::{
31    base::{Constant, SharedSecret},
32    protocol::{ProtocolMessage, ProtocolMessageWrapper},
33    utils::{decrypt, encrypt},
34};
35
36// Macros.
37
/// Reborrows the `inner` field of a `&mut _` or `Pin<&mut _>` binding as a `Pin<&mut _>`.
///
/// The binding is only borrowed, so it stays usable afterwards. In this module `inner` is an
/// `AsyncBincodeReader` or `AsyncBincodeWriter`. `Pin::new` requires the field type to be `Unpin`,
/// and going through a `Pin<&mut Self>` binding additionally requires `Self: Unpin`.
macro_rules! pinned_inner {
    ($self:ident) => {
        Pin::new(&mut $self.inner)
    };
}
44
/// Consumes a `Pin<&mut Self>` binding (via `Pin::get_mut`) and returns its `inner` field pinned.
///
/// Use this when the pinned `self` is not needed afterwards. In this module `inner` is an
/// `AsyncBincodeReader` or `AsyncBincodeWriter`; both `Self` and the field type must be `Unpin`.
macro_rules! take_pinned_inner {
    ($self:ident) => {
        Pin::new(&mut $self.get_mut().inner)
    };
}
51
/// Turns a pinned `BuffedStream` into a pinned reference to its read half (`inner_read`).
///
/// The `Pin<&mut _>` binding passed in is consumed (it goes through `Pin::get_mut`).
macro_rules! take_pinned_inner_read {
    ($self:ident) => {{
        let this = $self.get_mut();
        ::std::pin::Pin::new(&mut this.inner_read)
    }};
}
58
/// Turns a pinned `BuffedStream` into a pinned reference to its write half (`inner_write`).
///
/// The `Pin<&mut _>` binding passed in is consumed (it goes through `Pin::get_mut`).
macro_rules! take_pinned_inner_write {
    ($self:ident) => {{
        let this = $self.get_mut();
        ::std::pin::Pin::new(&mut this.inner_write)
    }};
}
65
/// Reborrows the decryption buffer (`read_stream`, a `SimplexStream`) of a `BuffedStreamReadHalf`
/// binding as a `Pin<&mut SimplexStream>`, leaving the binding usable afterwards.
///
/// # Panics
///
/// Panics if `read_stream` is `None`. It is only populated when encryption is enabled, so callers
/// must check for a shared secret first.
macro_rules! pinned_read_stream {
    ($self:ident) => {
        Pin::new(
            $self
                .read_stream
                .as_mut()
                .expect("`read_stream` must be set whenever encryption is enabled"),
        )
    };
}
72
/// Consumes a `Pin<&mut BuffedStreamReadHalf<T>>` binding and returns its decryption buffer
/// (`read_stream`) as a `Pin<&mut SimplexStream>`.
///
/// # Panics
///
/// Panics if `read_stream` is `None`. It is only populated when encryption is enabled, so callers
/// must check for a shared secret first.
macro_rules! take_pinned_read_stream {
    ($self:ident) => {
        Pin::new(
            $self
                .get_mut()
                .read_stream
                .as_mut()
                .expect("`read_stream` must be set whenever encryption is enabled"),
        )
    };
}
79
80// Types.
81
82pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
83pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
84
85/// BuffedStream type.
86///
87/// This type is a wrapper around a stream that provides buffering and encryption/decryption functionality.
88/// It is used to provide a bincode-centric stream that can be used to send and receive data
89/// in a more efficient manner.
90///
91/// > This type is used to provide a bincode-centric stream that can be used to send and receive data
92/// > so it is inadvisable to use any other methods than the `push` and `pull` methods from the protocol
93/// > module.  Using `read` and `write` directly will bypass the normal logic, and should only be used when
94/// > you know what you are doing (most common use case is pumping data).
95///
96/// The `shared_secret` field is used to encrypt and decrypt data.
97/// The `read_stream` field is used to buffer data that has been decrypted.
98pub struct BuffedStream<R, W> {
99    /// The read half of the buffered stream
100    inner_read: BuffedStreamReadHalf<R>,
101    /// The write half of the buffered stream
102    inner_write: BuffedStreamWriteHalf<W>,
103}
104
105// Impl.
106
107impl<R, W> BuffedStream<R, W> {
108    /// Sets the shared secret for the stream, and enables encryption / decryption.
109    pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
110        let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
111
112        self.inner_read.shared_secret = Some(secret_clone);
113        self.inner_read.read_stream = Some(SimplexStream::new_unsplit(Constant::BUFFER_SIZE));
114        self.inner_write.shared_secret = Some(shared_secret);
115
116        self
117    }
118
119    /// Splits the buffered stream into its read and write halves.
120    ///
121    /// This allows the read and write halves to be used independently, potentially
122    /// from different tasks or threads.
123    pub fn into_split(self) -> (BuffedStreamReadHalf<R>, BuffedStreamWriteHalf<W>) {
124        (self.inner_read, self.inner_write)
125    }
126}
127
128impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
129    fn from(stream: TcpStream) -> Self {
130        let (read, write) = stream.into_split();
131
132        Self {
133            inner_read: BuffedStreamReadHalf::new(read),
134            inner_write: BuffedStreamWriteHalf::new(write),
135        }
136    }
137}
138
139impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
140where
141    T: AsyncRead + AsyncWrite + Unpin,
142{
143    fn from(stream: T) -> Self {
144        let (read, write) = tokio::io::split(stream);
145
146        Self {
147            inner_read: BuffedStreamReadHalf::new(read),
148            inner_write: BuffedStreamWriteHalf::new(write),
149        }
150    }
151}
152
153impl<R, W> BuffedStream<R, W>
154where
155    R: AsyncRead + Unpin,
156    W: AsyncWrite + Unpin,
157{
158    pub fn new(inner_read: R, inner_write: W) -> Self {
159        Self {
160            inner_read: BuffedStreamReadHalf::new(inner_read),
161            inner_write: BuffedStreamWriteHalf::new(inner_write),
162        }
163    }
164}
165
166impl<R> BuffedStream<R, OwnedWriteHalf> {
167    pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
168        self.inner_write.inner.get_ref()
169    }
170}
171
172impl<W> BuffedStream<OwnedReadHalf, W> {
173    pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
174        self.inner_read.inner.get_ref()
175    }
176}
177
178// Trait impls.
179
180impl<R, W> Stream for BuffedStream<R, W>
181where
182    R: AsyncRead + Unpin,
183    W: Unpin,
184{
185    type Item = std::io::Result<ProtocolMessage>;
186
187    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
188        take_pinned_inner_read!(self).poll_next(cx)
189    }
190}
191
192impl<R, W> Sink<ProtocolMessage> for BuffedStream<R, W>
193where
194    R: Unpin,
195    W: AsyncWrite + Unpin,
196{
197    type Error = std::io::Error;
198
199    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
200        take_pinned_inner_write!(self).poll_ready(cx)
201    }
202
203    fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
204        take_pinned_inner_write!(self).start_send(item)
205    }
206
207    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
208        futures::Sink::<ProtocolMessage>::poll_flush(take_pinned_inner_write!(self), cx)
209    }
210
211    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212        take_pinned_inner_write!(self).poll_close(cx)
213    }
214}
215
216impl<R, W> AsyncRead for BuffedStream<R, W>
217where
218    R: AsyncRead + Unpin,
219    W: Unpin,
220{
221    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
222        take_pinned_inner_read!(self).poll_read(cx, buf)
223    }
224}
225
226impl<R, W> AsyncWrite for BuffedStream<R, W>
227where
228    R: Unpin,
229    W: AsyncWrite + Unpin,
230{
231    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
232        take_pinned_inner_write!(self).poll_write(cx, buf)
233    }
234
235    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
236        AsyncWrite::poll_flush(take_pinned_inner_write!(self), cx)
237    }
238
239    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
240        take_pinned_inner_write!(self).poll_shutdown(cx)
241    }
242}
243
244// Split streams.
245
246pub struct BuffedStreamReadHalf<T> {
247    inner: AsyncBincodeReader<T, ProtocolMessageWrapper>,
248    shared_secret: Option<SharedSecret>,
249    read_stream: Option<SimplexStream>,
250}
251
252impl<T> BuffedStreamReadHalf<T>
253where
254    T: AsyncRead + Unpin,
255{
256    fn new(stream: T) -> Self {
257        Self {
258            inner: AsyncBincodeReader::from(stream),
259            shared_secret: None,
260            read_stream: None,
261        }
262    }
263}
264
265impl<T> Stream for BuffedStreamReadHalf<T>
266where
267    T: AsyncRead + Unpin,
268{
269    type Item = std::io::Result<ProtocolMessage>;
270
271    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
272        // Get an option to the shared secret.
273        let key = self.shared_secret.as_ref().map(|s| SharedSecret::init_with(|| *s.expose_secret()));
274
275        match take_pinned_inner!(self).poll_next(cx) {
276            Poll::Ready(Some(Ok(wrapper))) => match wrapper {
277                ProtocolMessageWrapper::Plain(message) => Poll::Ready(Some(Ok(message))),
278                ProtocolMessageWrapper::Encrypted { nonce, data } => {
279                    let Some(key) = key else {
280                        return Poll::Ready(Some(Err(std::io::Error::new(
281                            std::io::ErrorKind::InvalidData,
282                            "Received encrypted message without shared secret on this end",
283                        ))));
284                    };
285
286                    let Ok(decrypted_data) = decrypt(&key, &data, &nonce) else {
287                        return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Decryption failed"))));
288                    };
289
290                    let Ok(message) = bincode::deserialize::<ProtocolMessage>(&decrypted_data) else {
291                        return Poll::Ready(Some(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to deserialize decrypted data"))));
292                    };
293
294                    Poll::Ready(Some(Ok(message)))
295                }
296            },
297            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(std::io::Error::new(
298                std::io::ErrorKind::InvalidData,
299                format!("Error on bincode reading during stream next: {}", e),
300            )))),
301            Poll::Ready(None) => Poll::Ready(None),
302            Poll::Pending => Poll::Pending,
303        }
304    }
305}
306
307impl<T> AsyncRead for BuffedStreamReadHalf<T>
308where
309    T: AsyncRead + Unpin,
310{
311    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>) -> Poll<std::io::Result<()>> {
312        // Read directly from the inner stream if there is no shared secret.
313        // This is the case where we are not encrypting the stream.
314        //
315        // Basically, this is an optimization that both `poll_read` and `poll_write` can use for
316        // the case where we are not encrypting the stream.
317        if self.shared_secret.is_none() {
318            return Pin::new(self.inner.get_mut()).poll_read(cx, buf);
319        }
320
321        // Use the "self" reader to get the next packet (and perform any needed decryption).
322        // Performance optimization opportunity: We could loop on poll_next here until we either 
323        // get Poll::Pending, or we no longer have space in the read_stream buffer.
324        let result = self.as_mut().poll_next(cx);
325
326        match result {
327            Poll::Ready(Some(Ok(message))) => {
328                let ProtocolMessage::Data(data) = message else {
329                    return Poll::Ready(Err(std::io::Error::new(
330                        std::io::ErrorKind::InvalidData,
331                        "Received non-data message during `poll_read`, which shouldn't happen",
332                    )));
333                };
334
335                // We have the data, so we can write it to the `read_stream`.
336                let written = ready!(pinned_read_stream!(self).poll_write(cx, &data)?);
337
338                // Fail if the interim buffer is too small.
339                if written < data.len() {
340                    return Poll::Ready(Err(std::io::Error::new(
341                        std::io::ErrorKind::InvalidData,
342                        "Decryption stream buffer overflow (shouldn't happen unless there is a mismatched buffer size between client and server)",
343                    )));
344                }
345
346                // Flush the `read_stream` to ensure that the data is available for reading (see below).
347                ready!(pinned_read_stream!(self).poll_flush(cx)?);
348            }
349            Poll::Ready(Some(Err(e))) => {
350                // This is the case where we have a bincode error, so we should return the error.
351
352                return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Error on bincode reading during pump: {}", e))));
353            }
354            Poll::Ready(None) => {
355                // If we read no data from the inner buffer, then we are "shutdown",
356                // so we should shutdown the write side of the `decryption_stream`, and
357                // return the final poll result (bottom of function).
358
359                ready!(pinned_read_stream!(self).poll_shutdown(cx)?);
360            }
361            Poll::Pending => {
362                // If we are pending, then we should pass through to the underlying decryption stream (so do nothing here).
363                // The underlying decryption stream will be properly shutdown in the case of a shutdown on the inner stream.
364            }
365        }
366
367        // At this point, if there was data to decrypt, we have decrypted it; if not, we may have some data in the
368        // decrypted stream, so we just offload onto its `poll_read` method.
369        take_pinned_read_stream!(self).poll_read(cx, buf)
370    }
371}
372
373pub struct BuffedStreamWriteHalf<T> {
374    inner: AsyncBincodeWriter<T, ProtocolMessageWrapper, AsyncDestination>,
375    shared_secret: Option<SharedSecret>,
376}
377
378impl<T> BuffedStreamWriteHalf<T>
379where
380    T: AsyncWrite + Unpin,
381{
382    fn new(stream: T) -> Self {
383        Self {
384            inner: AsyncBincodeWriter::from(stream).for_async(),
385            shared_secret: None,
386        }
387    }
388}
389
390impl<T> Sink<ProtocolMessage> for BuffedStreamWriteHalf<T>
391where
392    T: AsyncWrite + Unpin,
393{
394    type Error = std::io::Error;
395
396    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
397        take_pinned_inner!(self)
398            .poll_ready(cx)
399            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
400    }
401
402    fn start_send(self: Pin<&mut Self>, item: ProtocolMessage) -> Result<(), Self::Error> {
403        if let Some(key) = self.shared_secret.as_ref() {
404            let encrypted_data = encrypt(
405                key,
406                &bincode::serialize(&item).map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to serialize message"))?,
407            )
408            .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Encryption failed"))?;
409
410            let message = ProtocolMessageWrapper::Encrypted {
411                nonce: encrypted_data.nonce,
412                data: encrypted_data.data,
413            };
414
415            take_pinned_inner!(self)
416                .start_send(message)
417                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write encrypted packet: {}", e)))?;
418
419            return Ok(());
420        }
421
422        take_pinned_inner!(self)
423            .start_send(ProtocolMessageWrapper::Plain(item))
424            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to write plain packet: {}", e)))
425    }
426
427    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
428        futures::Sink::<ProtocolMessageWrapper>::poll_flush(take_pinned_inner!(self), cx)
429            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
430    }
431
432    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
433        take_pinned_inner!(self)
434            .poll_close(cx)
435            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
436    }
437}
438
439impl<T> AsyncWrite for BuffedStreamWriteHalf<T>
440where
441    T: AsyncWrite + Unpin,
442{
443    fn poll_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<std::io::Result<usize>> {
444        // If there is no shared secret, then we are not encrypting the stream,
445        // so we can just write directly to the inner stream.
446        //
447        // Basically, this is an optimization that both `poll_read` and `poll_write` can use for
448        // the case where we are not encrypting the stream.
449        if self.shared_secret.is_none() {
450            return Pin::new(self.inner.get_mut()).poll_write(cx, buf);
451        }
452
453        // First, we need to pare down the data to the maximum size of the encrypted packet, if needed.
454        let max_size = Constant::BUFFER_SIZE - Constant::ENCRYPTION_OVERHEAD;
455        let amt = std::cmp::min(buf.len(), max_size);
456        let buf = &buf[..amt];
457
458        let message = ProtocolMessage::Data(buf.to_vec());
459
460        // Write the encrypted data to the "self" `start_send`, which performs any needed encryption logic.
461        self.as_mut()
462            .start_send(message)
463            .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to write encrypted packet"))?;
464
465        // Need to report the amount of data that was written _from the input_, not the _actual_ amount written to the inner stream.
466        // This allows the caller to know how much of _their_ data was written, which is all that matters.
467        Poll::Ready(Ok(buf.len()))
468    }
469
470    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
471        pinned_inner!(self)
472            .poll_flush(cx)
473            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to flush inner stream: {}", e)))
474    }
475
476    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
477        pinned_inner!(self)
478            .poll_close(cx)
479            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Failed to shutdown inner stream: {}", e)))
480    }
481}
482
483// Tests.
484
#[cfg(test)]
mod tests {
    use futures::future::join_all;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use crate::utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption};

    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";

        client.write_all(data).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!(data, &received[..]);
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        // Distinct payloads so that lost, duplicated or reordered packets are detected.
        let data1 = b"Hello, world!";
        let data2 = b"Goodbye, world!";

        client.write_all(data1).await.unwrap();
        client.write_all(data2).await.unwrap();
        client.shutdown().await.unwrap();

        let mut received = Vec::new();
        server.read_to_end(&mut received).await.unwrap();

        assert_eq!([&data1[..], &data2[..]].concat(), received);
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_large_data() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();

        let data = b"Hello, world!";
        let data = data.repeat(10000);

        let data_clone = data.clone();

        let write_task = tokio::spawn(async move {
            client.write_all(&data_clone).await.unwrap();
            client.shutdown().await.unwrap();
        });

        let read_task = tokio::spawn(async move {
            let mut received = Vec::new();
            server.read_to_end(&mut received).await.unwrap();
            assert_eq!(data.len(), received.len());
            // Compare contents too; `assert!` avoids dumping ~130 KB on failure.
            assert!(data == received, "received bytes differ from the bytes sent");
        });

        join_all([write_task, read_task]).await.into_iter().collect::<Result<Vec<_>, _>>().unwrap();
    }
}