ratrodlib/
buffed_stream.rs

//! Buffed stream module.
//!
//! This module contains the `BuffedStream` type, which is a wrapper around a split stream that
//! provides buffering and encryption/decryption functionality.
//!
//! It provides a bincode-centric stream that can be used to send and receive framed messages
//! efficiently.  In addition, the `BincodeSend::push` and `BincodeReceive::pull` implementations
//! are designed to "transparently" handle encryption and decryption of the data being sent
//! and received (for the "pump" phase of the lifecycle).
10
11use anyhow::{Context as _, anyhow};
12use bincode::Encode;
13use bytes::{BufMut, Bytes, BytesMut};
14use secrecy::ExposeSecret;
15use tokio::{
16    io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, DuplexStream, ReadHalf, WriteHalf},
17    net::{
18        TcpStream,
19        tcp::{OwnedReadHalf, OwnedWriteHalf},
20    },
21};
22use tracing::{error, warn};
23
24use crate::{
25    base::{Constant, Res, SharedSecret, Void},
26    protocol::{BincodeReceive, BincodeSend, ProtocolMessage, ProtocolMessageGuard, ProtocolMessageGuardBuilder},
27    utils::{decrypt_in_place, encrypt_into},
28};
29
30// Traits.
31
32/// A trait for splitting a [`BuffedStream`] into its read and write halves.
33pub trait BincodeSplit {
34    type ReadHalf: BincodeReceive;
35    type WriteHalf: BincodeSend;
36
37    /// Takes and splits the buffered stream into its read and write halves.
38    ///
39    /// This allows the read and write halves to be used independently, potentially
40    /// from different tasks or threads.
41    fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf);
42
43    /// Splits the buffered stream into mutably borrowed read and write halves.
44    ///
45    /// This allows the read and write halves to be used independently, potentially
46    /// from different tasks or threads.
47    fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf);
48}
49
50// Types.
51
52/// A type alias for a buffed [`TcpStream`].
53pub type BuffedTcpStream = BuffedStream<OwnedReadHalf, OwnedWriteHalf>;
54/// A type alias for a buffed [`DuplexStream`].
55pub type BuffedDuplexStream = BuffedStream<ReadHalf<DuplexStream>, WriteHalf<DuplexStream>>;
56
57/// BuffedStream type.
58///
59/// This type is a wrapper around a stream that provides buffering and encryption/decryption functionality.
60/// It is used to provide a bincode-centric stream that can be used to send and receive data
61/// in a more efficient manner.  In order to make _usual_ future splitting more ergonomic, this type
62/// is designed to be a wrapper around the split halves.
63///
64/// > This type is used to provide a bincode-centric stream that can be used to send and receive data
65/// > so it is inadvisable to use any other methods than the `push` and `pull` methods from the protocol
66/// > module.
67pub struct BuffedStream<R, W> {
68    /// The read half of the buffered stream
69    inner_read: BuffedStreamReadHalf<R>,
70    /// The write half of the buffered stream
71    inner_write: BuffedStreamWriteHalf<W>,
72}
73
74// Impl.
75
76impl<R, W> BuffedStream<R, W> {
77    /// Sets the shared secret for the stream, and enables encryption / decryption.
78    pub fn with_encryption(mut self, shared_secret: SharedSecret) -> Self {
79        let secret_clone = SharedSecret::init_with(|| *shared_secret.expose_secret());
80
81        self.inner_read.shared_secret = Some(secret_clone);
82        self.inner_write.shared_secret = Some(shared_secret);
83
84        self
85    }
86}
87
88impl<R, W> BincodeSplit for BuffedStream<R, W>
89where
90    R: AsyncRead + Unpin,
91    W: AsyncWrite + Unpin,
92{
93    type ReadHalf = BuffedStreamReadHalf<R>;
94    type WriteHalf = BuffedStreamWriteHalf<W>;
95
96    fn into_split(self) -> (Self::ReadHalf, Self::WriteHalf) {
97        (self.inner_read, self.inner_write)
98    }
99
100    fn split(&mut self) -> (&mut Self::ReadHalf, &mut Self::WriteHalf) {
101        (&mut self.inner_read, &mut self.inner_write)
102    }
103}
104
105impl From<TcpStream> for BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
106    fn from(stream: TcpStream) -> Self {
107        let (read, write) = stream.into_split();
108
109        Self {
110            inner_read: BuffedStreamReadHalf::new(read),
111            inner_write: BuffedStreamWriteHalf::new(write),
112        }
113    }
114}
115
116impl<T> From<T> for BuffedStream<ReadHalf<T>, WriteHalf<T>>
117where
118    T: AsyncRead + AsyncWrite + Unpin,
119{
120    fn from(stream: T) -> Self {
121        let (read, write) = tokio::io::split(stream);
122
123        Self {
124            inner_read: BuffedStreamReadHalf::new(read),
125            inner_write: BuffedStreamWriteHalf::new(write),
126        }
127    }
128}
129
130impl<R, W> BuffedStream<R, W>
131where
132    R: AsyncRead + Unpin,
133    W: AsyncWrite + Unpin,
134{
135    pub fn from_splits(inner_read: R, inner_write: W) -> Self {
136        Self {
137            inner_read: BuffedStreamReadHalf::new(inner_read),
138            inner_write: BuffedStreamWriteHalf::new(inner_write),
139        }
140    }
141}
142
143impl<R> BuffedStream<R, OwnedWriteHalf> {
144    pub fn as_inner_tcp_write_ref(&self) -> &OwnedWriteHalf {
145        &self.inner_write.inner
146    }
147
148    pub fn as_inner_tcp_write_mut(&mut self) -> &mut OwnedWriteHalf {
149        &mut self.inner_write.inner
150    }
151}
152
153impl<W> BuffedStream<OwnedReadHalf, W> {
154    pub fn as_inner_tcp_read_ref(&self) -> &OwnedReadHalf {
155        &self.inner_read.inner
156    }
157
158    pub fn as_inner_tcp_read_mut(&mut self) -> &mut OwnedReadHalf {
159        &mut self.inner_read.inner
160    }
161}
162
163impl BuffedStream<OwnedReadHalf, OwnedWriteHalf> {
164    pub fn take(self) -> Res<TcpStream> {
165        let read = self.inner_read.take();
166        let write = self.inner_write.take();
167
168        read.reunite(write).context("Failed to reunite read and write halves")
169    }
170}
171
172// Trait impls.
173
174impl<R, W> BincodeSend for BuffedStream<R, W>
175where
176    R: Unpin,
177    W: AsyncWrite + Unpin,
178{
179    async fn push<E>(&mut self, message: E) -> Void
180    where
181        E: Encode,
182    {
183        self.inner_write.push(message).await?;
184
185        Ok(())
186    }
187
188    async fn close(&mut self) -> Void {
189        self.inner_write.close().await?;
190
191        Ok(())
192    }
193}
194
195impl<R, W> BincodeReceive for BuffedStream<R, W>
196where
197    R: AsyncRead + Unpin,
198    W: Unpin,
199{
200    async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
201        self.inner_read.pull().await
202    }
203}
204
205// Split streams.
206
207/// A type for the read half of a buffed stream.
208pub struct BuffedStreamReadHalf<T> {
209    inner: T,
210    shared_secret: Option<SharedSecret>,
211    buffer: BytesMut,
212    decryption_buffer: BytesMut,
213}
214
215impl<T> BuffedStreamReadHalf<T>
216where
217    T: AsyncRead + Unpin,
218{
219    /// Creates a new `BuffedStreamReadHalf` from the given [`AsyncRead`].
220    fn new(async_read: T) -> Self {
221        Self {
222            inner: async_read,
223            shared_secret: None,
224            buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
225            decryption_buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
226        }
227    }
228
229    /// Takes the inner stream and returns it.
230    fn take(self) -> T {
231        if !self.buffer.is_empty() {
232            warn!("Buffer was not empty when taking the stream");
233        }
234
235        self.inner
236    }
237}
238
239impl<T> BincodeReceive for BuffedStreamReadHalf<T>
240where
241    T: AsyncRead + Unpin,
242{
243    async fn pull(&mut self) -> Res<ProtocolMessageGuard> {
244        // Use reserve here to make sure we have at least the space for the next read.
245        // The `Bytes` go _with_ the returned guard, so we need to make sure we have enough space
246        // for the next read.
247        //
248        // In many cases, the guards may have been dropped, and reserve will not allocate,
249        // but we need to make sure we have enough space for the next read in case they haven't.
250        //
251        // In practice, this is used in the pump, so guards are usually dropped within a
252        // few reads, but not after _every_ read (which is why we don't use `try_reclaim` here).
253
254        self.buffer.clear();
255        self.decryption_buffer.clear();
256        self.buffer.reserve(Constant::BUFFER_SIZE);
257        self.decryption_buffer.reserve(Constant::BUFFER_SIZE);
258
259        // First, read the encryption flag.
260
261        let is_encrypted = match self.inner.read_u8().await {
262            Ok(1) => true,
263            Ok(_) => false,
264            Err(e) => {
265                if e.kind() == std::io::ErrorKind::UnexpectedEof {
266                    return Ok(ProtocolMessageGuardBuilder {
267                        buffer: Bytes::new(),
268                        inner_builder: |_| ProtocolMessage::Shutdown,
269                    }
270                    .build());
271                } else {
272                    return Err(anyhow!("Failed to read encryption flag: {}", e));
273                }
274            }
275        };
276
277        // Next, read the nonce, if the message is encrypted.
278
279        let maybe_encryption_data = if is_encrypted {
280            let mut nonce = [0; Constant::SHARED_SECRET_NONCE_SIZE];
281            self.inner.read_exact(&mut nonce).await.context("Failed to read nonce")?;
282
283            Some(nonce)
284        } else {
285            None
286        };
287
288        // Read the size of the message from the stream.
289
290        let message_size = self.inner.read_u64().await.context("Failed to read size")? as usize;
291
292        // Bail if we got a FIN (?) or a message that is too large.
293
294        if message_size == 0 {
295            let guard = ProtocolMessageGuardBuilder {
296                buffer: Bytes::new(),
297                inner_builder: |_| ProtocolMessage::Shutdown,
298            }
299            .build();
300
301            return Ok(guard);
302        }
303
304        if message_size > Constant::BUFFER_SIZE {
305            return Err(anyhow!("Message size is too large for the buffer during pull"));
306        }
307
308        // Read the stream into the buffer.
309
310        // SAFTY: We know _exactly_ how many bytes we are going to read, so we can safely
311        // set the length of the buffer to the size of the message.  However, we check
312        // afterward just to make sure.
313        unsafe { self.buffer.set_len(message_size) };
314        let n = self.inner.read_exact(&mut self.buffer).await?;
315
316        if message_size != n {
317            return Err(anyhow!("Failed to read message: expected {} bytes, got {}", message_size, n));
318        }
319
320        // Split off the needed bytes for the borrowed message.
321
322        let data_buffer = self.buffer.split().freeze();
323
324        // Perform any needed encryption.
325
326        let data_buffer = if let Some(nonce) = maybe_encryption_data {
327            let Some(key) = self.shared_secret.as_ref() else {
328                return Err(anyhow!("Shared secret is not set when receiving encrypted message"));
329            };
330
331            self.decryption_buffer.put(data_buffer);
332
333            // The length is of the buffer is adjusted by this call.
334            decrypt_in_place(key, &nonce, &mut self.decryption_buffer).context("Failed to decrypt message")?;
335
336            self.decryption_buffer.split().freeze()
337        } else {
338            data_buffer
339        };
340
341        // Prepare the result into the guard.
342
343        Ok(ProtocolMessageGuardBuilder {
344            buffer: data_buffer,
345            inner_builder: |data| match bincode::borrow_decode_from_slice::<ProtocolMessage<'_>, _>(data, Constant::BINCODE_CONFIG) {
346                Ok((message, n)) => {
347                    if n != data.len() {
348                        error!("Failed to decrypt message: expected {} bytes, got {}", data.len(), n);
349                        return ProtocolMessage::Shutdown;
350                    }
351
352                    message
353                }
354                Err(e) => {
355                    error!("Failed to decode message: {}", e);
356                    ProtocolMessage::Shutdown
357                }
358            },
359        }
360        .build())
361    }
362}
363
364/// A type for the write half of a buffed stream.
365pub struct BuffedStreamWriteHalf<T> {
366    inner: T,
367    shared_secret: Option<SharedSecret>,
368    buffer: BytesMut,
369}
370
371impl<T> BuffedStreamWriteHalf<T>
372where
373    T: AsyncWrite + Unpin,
374{
375    /// Creates a new `BuffedStreamWriteHalf` from the given [`AsyncWrite`].
376    fn new(async_write: T) -> Self {
377        Self {
378            inner: async_write,
379            shared_secret: None,
380            buffer: BytesMut::with_capacity(2 * Constant::BUFFER_SIZE),
381        }
382    }
383
384    /// Takes the inner stream and returns it.
385    fn take(self) -> T {
386        self.inner
387    }
388}
389
390impl<T> BincodeSend for BuffedStreamWriteHalf<T>
391where
392    T: AsyncWrite + Unpin,
393{
394    async fn push<E>(&mut self, message: E) -> Void
395    where
396        E: Encode,
397    {
398        // Restore the buffer to its original state.
399        // We use `try_reclaim` here to avoid allocating a new buffer, and there
400        // _should_ never be a case where we cannot.  The produced interim `Bytes` are all dropped
401        // at the end of the function, and we currently have a unique mutable borrow, so there cannot
402        // be any other references to the buffer.
403
404        self.buffer.clear();
405        assert!(self.buffer.try_reclaim(2 * Constant::BUFFER_SIZE));
406
407        // Encode the message into the buffer.
408
409        // SAFETY: We know the size of the buffer, and we are going to fill it with
410        // the encoded message.  We also know that the buffer is empty, so we can safely
411        // set the length of the buffer to the size of the message.
412        unsafe { self.buffer.set_len(Constant::BUFFER_SIZE) };
413        let n = bincode::encode_into_slice(message, &mut self.buffer, Constant::BINCODE_CONFIG)?;
414        unsafe { self.buffer.set_len(n) };
415
416        // Encrypt the message if, needed.
417
418        let maybe_nonce = if let Some(key) = self.shared_secret.as_ref() {
419            // This call extends the buffer through `Extend`, so no need to update the length.
420            let nonce = encrypt_into(key, &mut self.buffer).context("Encryption failed")?;
421
422            Some(nonce)
423        } else {
424            None
425        };
426
427        // Ensure the buffer is not empty and is not too large.
428
429        let data_length = self.buffer.len();
430
431        if data_length == 0 {
432            return Err(anyhow!("Buffer is empty"));
433        }
434
435        if data_length > Constant::BUFFER_SIZE {
436            return Err(anyhow!("Buffer is too large"));
437        }
438
439        // Write enryption / nonce.
440
441        if let Some(nonce) = maybe_nonce {
442            self.inner.write_u8(1).await.context("Failed to write encryption flag")?;
443            self.inner.write_all(&nonce).await.context("Failed to write nonce")?;
444        } else {
445            self.inner.write_u8(0).await.context("Failed to write encryption flag")?;
446        }
447
448        // Write the message size.
449
450        self.inner.write_u64(data_length as u64).await.context("Failed to write size")?;
451
452        // Write the data.
453
454        self.inner.write_all(&self.buffer).await?;
455
456        // Flush the stream.
457
458        self.inner.flush().await.context("Failed to flush stream")?;
459
460        Ok(())
461    }
462
463    async fn close(&mut self) -> Void {
464        self.inner.shutdown().await.context("Failed to close stream")?;
465
466        Ok(())
467    }
468}
469
470// Tests.
471
#[cfg(test)]
mod tests {
    use crate::{
        protocol::{BincodeReceive, BincodeSend, ProtocolMessage},
        utils::tests::{generate_test_duplex, generate_test_duplex_with_encryption},
    };

    /// Pulls one message from `receiver` and asserts it is `Data` carrying exactly `expected`.
    async fn expect_data<S>(receiver: &mut S, expected: &[u8])
    where
        S: BincodeReceive,
    {
        let guard = receiver.pull().await.unwrap();

        let ProtocolMessage::Data(received) = *guard.message() else {
            panic!("Failed to receive message");
        };

        assert_eq!(expected, received);
    }

    #[tokio::test]
    async fn test_unencrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex();
        let data = b"Hello, world!";

        client.push(ProtocolMessage::Data(data)).await.unwrap();
        client.close().await.unwrap();

        expect_data(&mut server, data).await;
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();
        let data = b"Hello, world!";

        client.push(ProtocolMessage::Data(data)).await.unwrap();
        client.close().await.unwrap();

        expect_data(&mut server, data).await;
    }

    #[tokio::test]
    async fn test_e2e_encrypted_buffed_stream_with_multiple_packets() {
        let (mut client, mut server) = generate_test_duplex_with_encryption();
        let payloads: [&[u8]; 2] = [b"Hello, world!", b"Hello, wold!"];

        for payload in payloads {
            client.push(ProtocolMessage::Data(payload)).await.unwrap();
        }
        client.close().await.unwrap();

        for payload in payloads {
            expect_data(&mut server, payload).await;
        }
    }
}