ratrodlib/
buffed_stream.rs

1//! Buffed stream module.
2//!
3//! This module contains the `BuffedStream` type, which is a wrapper around a stream that provides
4//! buffering and encryption/decryption functionality.
5//!
6//! It is used to provide a bincode-centric stream that can be used to send and receive data
7//! in a more efficient manner.  In addition, the `AsyncRead` and `AsyncWrite` implementations
8//! are designed to "transparently" handle encryption and decryption of the data being sent
9//! and received (for the "pump" phase of the lifecycle).
10
11use anyhow::{Context as _, anyhow};
12use bincode::Encode;
13use bytes::{BufMut, Bytes, BytesMut};
14use secrecy::ExposeSecret;
15use tokio::{
16    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
17    net::{
18        TcpStream,
19        tcp::{OwnedReadHalf, OwnedWriteHalf},
20    },
21};
22use tracing::{error, warn};
23
24use crate::{
25    base::{Constant, Res, SharedSecret, Void},
26    protocol::{BincodeReceive, BincodeSend, ProtocolMessage, ProtocolMessageGuard, ProtocolMessageGuardBuilder},
27    utils::{decrypt_in_place, encrypt_into},
28};
29
30// Traits.
31
32/// A trait for splitting a [`BuffedStream`] into its read and write halves.
33pub trait BincodeSplit {
34    type ReadHalf: BincodeReceive;
35    type WriteHalf: BincodeSend;
36
37    /// Takes and splits the buffered stream into its read and write halves.
38    ///
39    /// This allows the read and write halves to be used independently, potentially
40    /// from different tasks or threads.
41    fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf);
42
43    /// Splits the buffered stream into mutably borrowed read and write halves.
44    ///
45    /// This allows the read and write halves to be used independently, potentially
46    /// from different tasks or threads.
47    fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf);
48}
49
50// Types.
51
52/// A type alias for a buffed [`TcpStream`].
53pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
54/// A type alias for a buffed [`DuplexStream`].
55pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
56
57/// BuffedStream type.
58///
59/// This type is a wrapper around a stream that provides buffering and encryption/decryption functionality.
60/// It is used to provide a bincode-centric stream that can be used to send and receive data
61/// in a more efficient manner.  In order to make _usual_ future splitting more ergonomic, this type
62/// is designed to be a wrapper around the split halves.
63///
64/// > This type is used to provide a bincode-centric stream that can be used to send and receive data
65/// > so it is inadvisable to use any other methods than the `push` and `pull` methods from the protocol
66/// > module.
67pub struct BuffedStream<R, W> {
68    /// The read half of the buffered stream
69    inner_read: BuffedStreamReadHalf<R>,
70    /// The write half of the buffered stream
71    inner_write: BuffedStreamWriteHalf<W>,
72}
73
74// Impl.
75
76impl<R, W> BuffedStream<R, W> {
77    /// Sets the shared secret for the stream, and enables encryption / decryption.
78    pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
79        let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
80
81        self.inner_read.shared_secret = Some(secret_clone);
82        self.inner_write.shared_secret = Some(shared_secret);
83
84        self
85    }
86}
87
88impl<R, W> BincodeSplit for BuffedStream<R, W>
89where
90    R: AsyncRead + Unpin,
91    W: AsyncWrite + Unpin,
92{
93    type ReadHalf = BuffedStreamReadHalf<R>;
94    type WriteHalf = BuffedStreamWriteHalf<W>;
95
96    fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf) {
97        (self.inner_read, self.inner_write)
98    }
99
100    fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf) {
101        (&mut self.inner_read, &mut self.inner_write)
102    }
103}
104
105impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
106    fn from(stream: TcpStream) -> Self {
107        let (read, write) = stream.into_split();
108
109        Self {
110            inner_read: BuffedStreamReadHalf::new(read),
111            inner_write: BuffedStreamWriteHalf::new(write),
112        }
113    }
114}
115
116impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
117where
118    T: AsyncRead + AsyncWrite + Unpin,
119{
120    fn from(stream: T) -> Self {
121        let (read, write) = tokio::io::split(stream);
122
123        Self {
124            inner_read: BuffedStreamReadHalf::new(read),
125            inner_write: BuffedStreamWriteHalf::new(write),
126        }
127    }
128}
129
130impl<R, W> BuffedStream<R, W>
131where
132    R: AsyncRead + Unpin,
133    W: AsyncWrite + Unpin,
134{
135    pub fn from_splits(inner_read: R, inner_write: W) -> Self {
136        Self {
137            inner_read: BuffedStreamReadHalf::new(inner_read),
138            inner_write: BuffedStreamWriteHalf::new(inner_write),
139        }
140    }
141}
142
143impl<R> BuffedStream<R, OwnedWriteHalf> {
144    pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
145        &self.inner_write.inner
146    }
147
148    pub fn as_inner_tcp_write_mut(&mut self) -> &mut OwnedWriteHalf {
149        &mut self.inner_write.inner
150    }
151}
152
153impl<W> BuffedStream<OwnedReadHalf, W> {
154    pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
155        &self.inner_read.inner
156    }
157
158    pub fn as_inner_tcp_read_mut(&mut self) -> &mut OwnedReadHalf {
159        &mut self.inner_read.inner
160    }
161}
162
163impl BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
164    pub fn take(self) -> Res<TcpStream> {
165        let read = self.inner_read.take();
166        let write = self.inner_write.take();
167
168        read.reunite(write).context("Failed to reunite read and write halves")
169    }
170}
171
172// Trait impls.
173
174impl<R, W> BincodeSend for BuffedStream<R, W>
175where
176    R: Unpin,
177    W: AsyncWrite + Unpin,
178{
179    async fn push<E>(&mut self, message: E) -> Void
180    where
181        E: Encode,
182    {
183        self.inner_write.push(message).await?;
184
185        Ok(())
186    }
187
188    async fn close(&mut self) -> Void {
189        self.inner_write.close().await?;
190
191        Ok(())
192    }
193}
194
195impl<R, W> BincodeReceive for BuffedStream<R, W>
196where
197    R: AsyncRead + Unpin,
198    W: Unpin,
199{
200    async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
201        self.inner_read.pull().await
202    }
203}
204
205// Split streams.
206
207/// A type for the read half of a buffed stream.
208pub struct BuffedStreamReadHalf<T> {
209    inner: T,
210    shared_secret: Option<SharedSecret>,
211    buffer: BytesMut,
212    decryption_buffer: BytesMut,
213}
214
215impl<T> BuffedStreamReadHalf<T>
216where
217    T: AsyncRead + Unpin,
218{
219    /// Creates a new `BuffedStreamReadHalf` from the given [`AsyncRead`].
220    fn new(async_read: T) -> Self {
221        Self {
222            inner: async_read,
223            shared_secret: None,
224            buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
225            decryption_buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
226        }
227    }
228
229    /// Takes the inner stream and returns it.
230    fn take(self) -> T {
231        if !self.buffer.is_empty() {
232            warn!("Buffer was not empty when taking the stream");
233        }
234
235        self.inner
236    }
237}
238
239impl<T> BincodeReceive for BuffedStreamReadHalf<T>
240where
241    T: AsyncRead + Unpin,
242{
243    async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
244        // Use reserve here to make sure we have at least the space for the next read.
245        // The `Bytes` go _with_ the returned guard, so we need to make sure we have enough space
246        // for the next read.
247        //
248        // In many cases, the guards may have been dropped, and reserve will not allocate,
249        // but we need to make sure we have enough space for the next read in case they haven't
250        //
251        // In practice, this is used in the pump, so guards are usually dropped within a
252        // few reads, but not after _every_ read (which is why we don't use `try_reclaim` here).
253
254        self.buffer.clear();
255        self.decryption_buffer.clear();
256        self.buffer.reserve(Constant::BUFFER_SIZE);
257        self.decryption_buffer.reserve(Constant::BUFFER_SIZE);
258
259        // First, read the encryption flag.
260
261        let is_encrypted = self.inner.read_u8().await.context("Failed to read encryption flag")? == 1;
262
263        // Next, read the nonce, if the message is encrypted.
264
265        let maybe_encryption_data = if is_encrypted {
266            let mut nonce = [0; Constant::SHARED_SECRET_NONCE_SIZE];
267            self.inner.read_exact(&mut nonce).await.context("Failed to read nonce")?;
268
269            Some(nonce)
270        } else {
271            None
272        };
273
274        // Read the size of the message from the stream.
275
276        let message_size = self.inner.read_u64().await.context("Failed to read size")? as usize;
277
278        // Bail if we got a FIN (?) or a message that is too large.
279
280        if message_size == 0 {
281            let guard = ProtocolMessageGuardBuilder {
282                buffer: Bytes::new(),
283                inner_builder: |_| ProtocolMessage::Shutdown,
284            }
285            .build();
286
287            return Ok(guard);
288        }
289
290        if message_size > Constant::BUFFER_SIZE {
291            return Err(anyhow!("Message size is too large for the buffer during pull"));
292        }
293
294        // Read the stream into the buffer.
295
296        // SAFTY: We know _exactly_ how many bytes we are going to read, so we can safely
297        // set the length of the buffer to the size of the message.  However, we check
298        // afterward just to make sure.
299        unsafe { self.buffer.set_len(message_size) };
300        let n = self.inner.read_exact(&mut self.buffer).await?;
301
302        if message_size != n {
303            return Err(anyhow!("Failed to read message: expected {} bytes, got {}", message_size, n));
304        }
305
306        // Split off the needed bytes for the borrowed message.
307
308        let data_buffer = self.buffer.split().freeze();
309
310        // Perform any needed encryption.
311
312        let data_buffer = if let Some(nonce) = maybe_encryption_data {
313            let Some(key) = self.shared_secret.as_ref() else {
314                return Err(anyhow!("Shared secret is not set when receiving encrypted message"));
315            };
316
317            self.decryption_buffer.put(data_buffer);
318
319            // The length is of the buffer is adjusted by this call.
320            decrypt_in_place(key, &nonce, &mut self.decryption_buffer).context("Failed to decrypt message")?;
321
322            self.decryption_buffer.split().freeze()
323        } else {
324            data_buffer
325        };
326
327        // Prepare the result into the guard.
328
329        Ok(ProtocolMessageGuardBuilder {
330            buffer: data_buffer,
331            inner_builder: |data| match bincode::borrow_decode_from_slice::<ProtocolMessage<'_>, _>(data, Constant::BINCODE_CONFIG) {
332                Ok((message, n)) => {
333                    if n != data.len() {
334                        error!("Failed to decrypt message: expected {} bytes, got {}", data.len(), n);
335                        return ProtocolMessage::Shutdown;
336                    }
337
338                    message
339                }
340                Err(e) => {
341                    error!("Failed to decode message: {}", e);
342                    ProtocolMessage::Shutdown
343                }
344            },
345        }
346        .build())
347    }
348}
349
350/// A type for the write half of a buffed stream.
351pub struct BuffedStreamWriteHalf<T> {
352    inner: T,
353    shared_secret: Option<SharedSecret>,
354    buffer: BytesMut,
355}
356
357impl<T> BuffedStreamWriteHalf<T>
358where
359    T: AsyncWrite + Unpin,
360{
361    /// Creates a new `BuffedStreamWriteHalf` from the given [`AsyncWrite`].
362    fn new(async_write: T) -> Self {
363        Self {
364            inner: async_write,
365            shared_secret: None,
366            buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
367        }
368    }
369
370    /// Takes the inner stream and returns it.
371    fn take(self) -> T {
372        self.inner
373    }
374}
375
376impl<T> BincodeSend for BuffedStreamWriteHalf<T>
377where
378    T: AsyncWrite + Unpin,
379{
380    async fn push<E>(&mut self, message: E) -> Void
381    where
382        E: Encode,
383    {
384        // Restore the buffer to its original state.
385        // We use `try_reclaim` here to avoid allocating a new buffer, and there
386        // _should never be a case where we cannot.  The produced interim `Bytes` are all dropped
387        // at the end of the function, and we currently have a unique mutable borrow, so there cannot
388        // be any other references to the buffer.
389
390        self.buffer.clear();
391        assert!(self.buffer.try_reclaim(2 * Constant::BUFFER_SIZE));
392
393        // Encode the message into the buffer.
394
395        // SAFETY: We know the size of the buffer, and we are going to fill it with
396        // the encoded message.  We also know that the buffer is empty, so we can safely
397        // set the length of the buffer to the size of the message.
398        unsafe { self.buffer.set_len(Constant::BUFFER_SIZE) };
399        let n = bincode::encode_into_slice(message, &mut self.buffer, Constant::BINCODE_CONFIG)?;
400        unsafe { self.buffer.set_len(n) };
401
402        let maybe_nonce = if let Some(key) = self.shared_secret.as_ref() {
403            // This call extends the buffer through `Extend`, so no need to update the length.
404            let nonce = encrypt_into(key, &mut self.buffer).context("Encryption failed")?;
405
406            Some(nonce)
407        } else {
408            None
409        };
410
411        // Ensure the buffer is not empty and is not too large.
412
413        let data_length = self.buffer.len();
414
415        if data_length == 0 {
416            return Err(anyhow!("Buffer is empty"));
417        }
418
419        if data_length > Constant::BUFFER_SIZE {
420            return Err(anyhow!("Buffer is too large"));
421        }
422
423        // Write enryption / nonce.
424
425        if let Some(nonce) = maybe_nonce {
426            self.inner.write_u8(1).await.context("Failed to write encryption flag")?;
427            self.inner.write_all(&nonce).await.context("Failed to write nonce")?;
428        } else {
429            self.inner.write_u8(0).await.context("Failed to write encryption flag")?;
430        }
431
432        // Write the message size.
433
434        self.inner.write_u64(data_length as u64).await.context("Failed to write size")?;
435
436        // Write the data.
437
438        self.inner.write_all(&self.buffer).await?;
439
440        // Flush the stream.
441
442        self.inner.flush().await.context("Failed to flush stream")?;
443
444        Ok(())
445    }
446
447    async fn close(&mut self) -> Void {
448        self.inner.shutdown().await.context("Failed to close stream")?;
449
450        Ok(())
451    }
452}
453
454// Tests.
455
#[cfg(test)]
mod tests {
    use crate::{
        protocol::{BincodeReceive, BincodeSend, ProtocolMessage},
        utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption},
    };

    /// Pushes each payload as a `Data` message in order, then closes the sender.
    async fn send_then_close(client: &mut impl BincodeSend, payloads: &[&[u8]]) {
        for &payload in payloads {
            client.push(ProtocolMessage::Data(payload)).await.unwrap();
        }

        client.close().await.unwrap();
    }

    /// Pulls the next message and asserts it is a `Data` message carrying `expected`.
    async fn expect_data(server: &mut impl BincodeReceive, expected: &[u8]) {
        let guard = server.pull().await.unwrap();
        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };

        assert_eq!(expected, received);
    }

    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex();
        let data: &[u8] = b"Hello, world!";

        send_then_close(&mut client, &[data]).await;
        expect_data(&mut server, data).await;
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();
        let data: &[u8] = b"Hello, world!";

        send_then_close(&mut client, &[data]).await;
        expect_data(&mut server, data).await;
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();
        let first: &[u8] = b"Hello, world!";
        let second: &[u8] = b"Hello, wold!";

        send_then_close(&mut client, &[first, second]).await;
        expect_data(&mut server, first).await;
        expect_data(&mut server, second).await;
    }
}