distant_net/common/transport/
framed.rs

1use std::future::Future;
2use std::time::Duration;
3use std::{fmt, io};
4
5use async_trait::async_trait;
6use bytes::{Buf, BytesMut};
7use log::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use super::{InmemoryTransport, Interest, Ready, Reconnectable, Transport};
12use crate::common::{utils, SecretKey32};
13
14mod backup;
15mod codec;
16mod exchange;
17mod frame;
18mod handshake;
19
20pub use backup::*;
21pub use codec::*;
22pub use exchange::*;
23pub use frame::*;
24pub use handshake::*;
25
/// Size, in bytes, of the scratch buffer used for each read from the inner transport while
/// assembling a frame (8 KiB).
const READ_BUF_SIZE: usize = 8192;

/// How long to back off after receiving `WouldBlock` before retrying inside looping operations
/// such as `read_frame`.
const SLEEP_DURATION: Duration = Duration::from_micros(1_000);
31
32/// Represents a wrapper around a [`Transport`] that reads and writes using frames defined by a
33/// [`Codec`].
34///
35/// [`try_read`]: Transport::try_read
36#[derive(Clone)]
37pub struct FramedTransport<T> {
38    /// Inner transport wrapped to support frames of data
39    inner: T,
40
41    /// Codec used to encoding outgoing bytes and decode incoming bytes
42    codec: BoxedCodec,
43
44    /// Bytes in queue to be read
45    incoming: BytesMut,
46
47    /// Bytes in queue to be written
48    outgoing: BytesMut,
49
50    /// Stores outgoing frames in case of transmission issues
51    pub backup: Backup,
52}
53
54impl<T> FramedTransport<T> {
55    pub fn new(inner: T, codec: BoxedCodec) -> Self {
56        Self {
57            inner,
58            codec,
59            incoming: BytesMut::with_capacity(READ_BUF_SIZE * 2),
60            outgoing: BytesMut::with_capacity(READ_BUF_SIZE * 2),
61            backup: Backup::new(),
62        }
63    }
64
65    /// Creates a new [`FramedTransport`] using the [`PlainCodec`]
66    pub fn plain(inner: T) -> Self {
67        Self::new(inner, Box::new(PlainCodec::new()))
68    }
69
70    /// Replaces the current codec with the provided codec. Note that any bytes in the incoming or
71    /// outgoing buffers will remain in the transport, meaning that this can cause corruption if
72    /// the bytes in the buffers do not match the new codec.
73    ///
74    /// For safety, use [`clear`] to wipe the buffers before further use.
75    ///
76    /// [`clear`]: FramedTransport::clear
77    pub fn set_codec(&mut self, codec: BoxedCodec) {
78        self.codec = codec;
79    }
80
81    /// Returns a reference to the codec used by the transport.
82    ///
83    /// ### Note
84    ///
85    /// Be careful when accessing the codec to avoid corrupting it through unexpected modifications
86    /// as this will place the transport in an undefined state.
87    pub fn codec(&self) -> &dyn Codec {
88        self.codec.as_ref()
89    }
90
91    /// Returns a mutable reference to the codec used by the transport.
92    ///
93    /// ### Note
94    ///
95    /// Be careful when accessing the codec to avoid corrupting it through unexpected modifications
96    /// as this will place the transport in an undefined state.
97    pub fn mut_codec(&mut self) -> &mut dyn Codec {
98        self.codec.as_mut()
99    }
100
101    /// Clears the internal transport buffers.
102    pub fn clear(&mut self) {
103        self.incoming.clear();
104        self.outgoing.clear();
105    }
106
107    /// Returns a reference to the inner value this transport wraps.
108    pub fn as_inner(&self) -> &T {
109        &self.inner
110    }
111
112    /// Returns a mutable reference to the inner value this transport wraps.
113    pub fn as_mut_inner(&mut self) -> &mut T {
114        &mut self.inner
115    }
116
117    /// Consumes this transport, returning the inner value that it wraps.
118    pub fn into_inner(self) -> T {
119        self.inner
120    }
121}
122
123impl<T> fmt::Debug for FramedTransport<T> {
124    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125        f.debug_struct("FramedTransport")
126            .field("incoming", &self.incoming)
127            .field("outgoing", &self.outgoing)
128            .field("backup", &self.backup)
129            .finish()
130    }
131}
132
133impl<T: Transport + 'static> FramedTransport<T> {
134    /// Converts this instance to a [`FramedTransport`] whose inner [`Transport`] is [`Box`]ed.
135    pub fn into_boxed(self) -> FramedTransport<Box<dyn Transport>> {
136        FramedTransport {
137            inner: Box::new(self.inner),
138            codec: self.codec,
139            incoming: self.incoming,
140            outgoing: self.outgoing,
141            backup: self.backup,
142        }
143    }
144}
145
146impl<T: Transport> FramedTransport<T> {
147    /// Waits for the transport to be ready based on the given interest, returning the ready status
148    pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
149        // If interest includes reading, we check if we already have a frame in our queue,
150        // as there can be a scenario where a frame was received and then the connection
151        // was closed, and we still want to be able to read the next frame is if it is
152        // available in the connection.
153        let ready = if interest.is_readable() && Frame::available(&self.incoming) {
154            Ready::READABLE
155        } else {
156            Ready::EMPTY
157        };
158
159        // If we know that we are readable and not checking for write status, we can short-circuit
160        // to avoid an async call by returning immediately that we are readable
161        if !interest.is_writable() && ready.is_readable() {
162            return Ok(ready);
163        }
164
165        // Otherwise, we need to check the status using the underlying transport and merge it with
166        // our current understanding based on internal state
167        Transport::ready(&self.inner, interest)
168            .await
169            .map(|r| r | ready)
170    }
171
172    /// Waits for the transport to be readable to follow up with [`try_read_frame`].
173    ///
174    /// [`try_read_frame`]: FramedTransport::try_read_frame
175    pub async fn readable(&self) -> io::Result<()> {
176        let _ = self.ready(Interest::READABLE).await?;
177        Ok(())
178    }
179
180    /// Waits for the transport to be writeable to follow up with [`try_write_frame`].
181    ///
182    /// [`try_write_frame`]: FramedTransport::try_write_frame
183    pub async fn writeable(&self) -> io::Result<()> {
184        let _ = self.ready(Interest::WRITABLE).await?;
185        Ok(())
186    }
187
188    /// Waits for the transport to be readable or writeable, returning the [`Ready`] status.
189    pub async fn readable_or_writeable(&self) -> io::Result<Ready> {
190        self.ready(Interest::READABLE | Interest::WRITABLE).await
191    }
192
193    /// Attempts to flush any remaining bytes in the outgoing queue, returning the total bytes
194    /// written as a result of the flush. Note that a return of 0 bytes does not indicate that the
195    /// underlying transport has closed, but rather that no bytes were flushed such as when the
196    /// outgoing queue is empty.
197    ///
198    /// This is accomplished by continually calling the inner transport's `try_write`. If 0 is
199    /// returned from a call to `try_write`, this will fail with [`ErrorKind::WriteZero`].
200    ///
201    /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
202    /// is not ready to write data.
203    ///
204    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
205    pub fn try_flush(&mut self) -> io::Result<usize> {
206        let mut bytes_written = 0;
207
208        // Continue to send from the outgoing buffer until we either finish or fail
209        while !self.outgoing.is_empty() {
210            match self.inner.try_write(self.outgoing.as_ref()) {
211                // Getting 0 bytes on write indicates the channel has closed
212                Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
213
214                // Successful write will advance the outgoing buffer
215                Ok(n) => {
216                    self.outgoing.advance(n);
217                    bytes_written += n;
218                }
219
220                // Any error (including WouldBlock) will get bubbled up
221                Err(x) => return Err(x),
222            }
223        }
224
225        Ok(bytes_written)
226    }
227
228    /// Flushes all buffered, outgoing bytes using repeated calls to [`try_flush`].
229    ///
230    /// [`try_flush`]: FramedTransport::try_flush
231    pub async fn flush(&mut self) -> io::Result<()> {
232        while !self.outgoing.is_empty() {
233            self.writeable().await?;
234            match self.try_flush() {
235                Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
236                    // NOTE: We sleep for a little bit before trying again to avoid pegging CPU
237                    tokio::time::sleep(SLEEP_DURATION).await
238                }
239                Err(x) => return Err(x),
240                Ok(_) => return Ok(()),
241            }
242        }
243
244        Ok(())
245    }
246
247    /// Reads a frame of bytes by using the [`Codec`] tied to this transport. Returns
248    /// `Ok(Some(frame))` upon reading a frame, or `Ok(None)` if the underlying transport has
249    /// closed.
250    ///
251    /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
252    /// is not ready to read data or has not received a full frame before waiting.
253    ///
254    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
255    pub fn try_read_frame(&mut self) -> io::Result<Option<OwnedFrame>> {
256        // Attempt to read a frame, returning the decoded frame if we get one, returning any error
257        // that is encountered from reading frames or failing to decode, or otherwise doing nothing
258        // and continuing forward.
259        macro_rules! read_next_frame {
260            () => {{
261                match Frame::read(&mut self.incoming) {
262                    None => (),
263                    Some(frame) => {
264                        if frame.is_nonempty() {
265                            self.backup.increment_received_cnt();
266                        }
267                        return Ok(Some(self.codec.decode(frame)?.into_owned()));
268                    }
269                }
270            }};
271        }
272
273        // If we have data remaining in the buffer, we first try to parse it in case we received
274        // multiple frames from a previous call.
275        //
276        // NOTE: This exists to avoid the situation where there is a valid frame remaining in the
277        //       incoming buffer, but it is never evaluated because a call to `try_read` returns
278        //       `WouldBlock`, 0 bytes, or some other error.
279        if !self.incoming.is_empty() {
280            read_next_frame!();
281        }
282
283        // Continually read bytes into the incoming queue and then attempt to tease out a frame
284        let mut buf = [0; READ_BUF_SIZE];
285
286        loop {
287            match self.inner.try_read(&mut buf) {
288                // Getting 0 bytes on read indicates the channel has closed. If we were still
289                // expecting more bytes for our frame, then this is an error, otherwise if we
290                // have nothing remaining if our queue then this is an expected end and we
291                // return None
292                Ok(0) if self.incoming.is_empty() => return Ok(None),
293                Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),
294
295                // Got some additional bytes, which we will add to our queue and then attempt to
296                // decode into a frame
297                Ok(n) => {
298                    self.incoming.extend_from_slice(&buf[..n]);
299                    read_next_frame!();
300                }
301
302                // Any error (including WouldBlock) will get bubbled up
303                Err(x) => return Err(x),
304            }
305        }
306    }
307
308    /// Reads a frame using [`try_read_frame`] and then deserializes the bytes into `D`.
309    ///
310    /// [`try_read_frame`]: FramedTransport::try_read_frame
311    pub fn try_read_frame_as<D: DeserializeOwned>(&mut self) -> io::Result<Option<D>> {
312        match self.try_read_frame() {
313            Ok(Some(frame)) => Ok(Some(utils::deserialize_from_slice(frame.as_item())?)),
314            Ok(None) => Ok(None),
315            Err(x) => Err(x),
316        }
317    }
318
319    /// Continues to invoke [`try_read_frame`] until a frame is successfully read, an error is
320    /// encountered that is not [`ErrorKind::WouldBlock`], or the underlying transport has closed.
321    ///
322    /// [`try_read_frame`]: FramedTransport::try_read_frame
323    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
324    pub async fn read_frame(&mut self) -> io::Result<Option<OwnedFrame>> {
325        loop {
326            self.readable().await?;
327
328            match self.try_read_frame() {
329                Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
330                    // NOTE: We sleep for a little bit before trying again to avoid pegging CPU
331                    tokio::time::sleep(SLEEP_DURATION).await
332                }
333                x => return x,
334            }
335        }
336    }
337
338    /// Reads a frame using [`read_frame`] and then deserializes the bytes into `D`.
339    ///
340    /// [`read_frame`]: FramedTransport::read_frame
341    pub async fn read_frame_as<D: DeserializeOwned>(&mut self) -> io::Result<Option<D>> {
342        match self.read_frame().await {
343            Ok(Some(frame)) => Ok(Some(utils::deserialize_from_slice(frame.as_item())?)),
344            Ok(None) => Ok(None),
345            Err(x) => Err(x),
346        }
347    }
348
349    /// Writes a `frame` of bytes by using the [`Codec`] tied to this transport.
350    ///
351    /// This is accomplished by continually calling the inner transport's `try_write`. If 0 is
352    /// returned from a call to `try_write`, this will fail with [`ErrorKind::WriteZero`].
353    ///
354    /// This call may return an error with [`ErrorKind::WouldBlock`] in the case that the transport
355    /// is not ready to write data or has not written the entire frame before waiting.
356    ///
357    /// [`ErrorKind::WriteZero`]: io::ErrorKind::WriteZero
358    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
359    pub fn try_write_frame<'a, F>(&mut self, frame: F) -> io::Result<()>
360    where
361        F: TryInto<Frame<'a>>,
362        F::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
363    {
364        // Grab the frame to send
365        let frame = frame
366            .try_into()
367            .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
368
369        // Encode the frame and store it in our outgoing queue
370        self.codec
371            .encode(frame.as_borrowed())?
372            .write(&mut self.outgoing);
373
374        // Update tracking stats and more of backup if frame is nonempty
375        if frame.is_nonempty() {
376            // Once the frame enters our queue, we count it as written, even if it isn't fully flushed
377            self.backup.increment_sent_cnt();
378
379            // Then we store the raw frame (non-encoded) for the future in case we need to retry
380            // sending it later (possibly with a different codec)
381            self.backup.push_frame(frame);
382        }
383
384        // Attempt to write everything in our queue
385        self.try_flush()?;
386
387        Ok(())
388    }
389
390    /// Serializes `value` into bytes and passes them to [`try_write_frame`].
391    ///
392    /// [`try_write_frame`]: FramedTransport::try_write_frame
393    pub fn try_write_frame_for<D: Serialize>(&mut self, value: &D) -> io::Result<()> {
394        let data = utils::serialize_to_vec(value)?;
395        self.try_write_frame(data)
396    }
397
398    /// Invokes [`try_write_frame`] followed by a continuous calls to [`try_flush`] until a frame
399    /// is successfully written, an error is encountered that is not [`ErrorKind::WouldBlock`], or
400    /// the underlying transport has closed.
401    ///
402    /// [`try_write_frame`]: FramedTransport::try_write_frame
403    /// [`try_flush`]: FramedTransport::try_flush
404    /// [`ErrorKind::WouldBlock`]: io::ErrorKind::WouldBlock
405    pub async fn write_frame<'a, F>(&mut self, frame: F) -> io::Result<()>
406    where
407        F: TryInto<Frame<'a>>,
408        F::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
409    {
410        self.writeable().await?;
411
412        match self.try_write_frame(frame) {
413            // Would block, so continually try to flush until good to go
414            Err(x) if x.kind() == io::ErrorKind::WouldBlock => loop {
415                self.writeable().await?;
416                match self.try_flush() {
417                    Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
418                        // NOTE: We sleep for a little bit before trying again to avoid pegging CPU
419                        tokio::time::sleep(SLEEP_DURATION).await
420                    }
421                    Err(x) => return Err(x),
422                    Ok(_) => return Ok(()),
423                }
424            },
425
426            // Already fully succeeded or failed
427            x => x,
428        }
429    }
430
431    /// Serializes `value` into bytes and passes them to [`write_frame`].
432    ///
433    /// [`write_frame`]: FramedTransport::write_frame
434    pub async fn write_frame_for<D: Serialize>(&mut self, value: &D) -> io::Result<()> {
435        let data = utils::serialize_to_vec(value)?;
436        self.write_frame(data).await
437    }
438
439    /// Executes the async function while the [`Backup`] of this transport is frozen.
440    pub async fn do_frozen<F, X>(&mut self, mut f: F) -> io::Result<()>
441    where
442        F: FnMut(&mut Self) -> X,
443        X: Future<Output = io::Result<()>>,
444    {
445        let is_frozen = self.backup.is_frozen();
446        self.backup.freeze();
447        let result = f(self).await;
448        self.backup.set_frozen(is_frozen);
449        result
450    }
451
452    /// Places the transport in **synchronize mode** where it communicates with the other side how
453    /// many frames have been sent and received. From there, any frames not received by the other
454    /// side are sent again and then this transport waits for any missing frames that it did not
455    /// receive from the other side.
456    ///
457    /// ### Note
458    ///
459    /// This will clear the internal incoming and outgoing buffers, so any frame that was in
460    /// transit in either direction will be dropped.
461    pub async fn synchronize(&mut self) -> io::Result<()> {
462        async fn synchronize_impl<T: Transport>(
463            this: &mut FramedTransport<T>,
464            backup: &mut Backup,
465        ) -> io::Result<()> {
466            type Stats = (u64, u64, u64);
467
468            // Stats in the form of (sent, received, available)
469            let sent_cnt: u64 = backup.sent_cnt();
470            let received_cnt: u64 = backup.received_cnt();
471            let available_cnt: u64 = backup
472                .frame_cnt()
473                .try_into()
474                .expect("Cannot case usize to u64");
475
476            // Clear our internal buffers
477            this.clear();
478
479            // Communicate frame counters with other side so we can determine how many frames to send
480            // and how many to receive. Wait until we get the stats from the other side, and then send
481            // over any missing frames.
482            trace!(
483                "Stats: sent = {sent_cnt}, received = {received_cnt}, available = {available_cnt}"
484            );
485            this.write_frame_for(&(sent_cnt, received_cnt, available_cnt))
486                .await?;
487            let (other_sent_cnt, other_received_cnt, other_available_cnt) =
488                this.read_frame_as::<Stats>().await?.ok_or_else(|| {
489                    io::Error::new(
490                        io::ErrorKind::UnexpectedEof,
491                        "Transport terminated before getting replay stats",
492                    )
493                })?;
494            trace!("Other stats: sent = {other_sent_cnt}, received = {other_received_cnt}, available = {other_available_cnt}");
495
496            // Determine how many frames we need to resend. This will either be (sent - received) or
497            // available frames, whichever is smaller.
498            let resend_cnt = std::cmp::min(
499                if sent_cnt > other_received_cnt {
500                    sent_cnt - other_received_cnt
501                } else {
502                    0
503                },
504                available_cnt,
505            );
506
507            // Determine how many frames we expect to receive. This will either be (received - sent) or
508            // available frames, whichever is smaller.
509            let expected_cnt = std::cmp::min(
510                if received_cnt < other_sent_cnt {
511                    other_sent_cnt - received_cnt
512                } else {
513                    0
514                },
515                other_available_cnt,
516            );
517
518            // Send all missing frames, removing any frames that we know have been received
519            trace!("Reducing internal replay frames to {resend_cnt}");
520            backup.truncate_front(resend_cnt.try_into().expect("Cannot cast usize to u64"));
521
522            debug!("Sending {resend_cnt} frames");
523            for frame in backup.frames() {
524                this.try_write_frame(frame.as_borrowed())?;
525            }
526            this.flush().await?;
527
528            // Receive all expected frames, placing their contents into our incoming queue
529            //
530            // NOTE: We do not increment our counter as this is done during `try_read_frame`, even
531            //       when the frame comes from our internal queue. To avoid duplicating the increment,
532            //       we do not increment the counter here.
533            debug!("Waiting for {expected_cnt} frames");
534            for i in 0..expected_cnt {
535                let frame = this.read_frame().await?.ok_or_else(|| {
536                    io::Error::new(
537                        io::ErrorKind::UnexpectedEof,
538                        format!(
539                            "Transport terminated before getting frame {}/{expected_cnt}",
540                            i + 1
541                        ),
542                    )
543                })?;
544
545                // Encode our frame and write it to be queued in our incoming data
546                // NOTE: We have to do encoding here as incoming bytes are expected to be encoded
547                this.codec.encode(frame)?.write(&mut this.incoming);
548            }
549
550            // Catch up our read count as we can have the case where the other side has a higher
551            // count than frames sent if some frames were fully dropped due to size limits
552            if backup.received_cnt() != other_sent_cnt {
553                warn!(
554                    "Backup received count ({}) != other sent count ({}), so resetting to match",
555                    backup.received_cnt(),
556                    other_sent_cnt
557                );
558                backup.set_received_cnt(other_sent_cnt);
559            }
560
561            Ok(())
562        }
563
564        // Swap out our backup so we don't mutate it from synchronization efforts
565        let mut backup = std::mem::take(&mut self.backup);
566
567        // Perform our operation, but don't return immediately so we can restore our backup
568        let result = synchronize_impl(self, &mut backup).await;
569
570        // Reset our backup to the real version
571        self.backup = backup;
572
573        result
574    }
575
576    /// Shorthand for creating a [`FramedTransport`] with a [`PlainCodec`] and then immediately
577    /// performing a [`client_handshake`], returning the updated [`FramedTransport`] on success.
578    ///
579    /// [`client_handshake`]: FramedTransport::client_handshake
580    #[inline]
581    pub async fn from_client_handshake(transport: T) -> io::Result<Self> {
582        let mut transport = Self::plain(transport);
583        transport.client_handshake().await?;
584        Ok(transport)
585    }
586
587    /// Perform the client-side of a handshake. See [`handshake`] for more details.
588    ///
589    /// [`handshake`]: FramedTransport::handshake
590    pub async fn client_handshake(&mut self) -> io::Result<()> {
591        self.handshake(Handshake::client()).await
592    }
593
594    /// Shorthand for creating a [`FramedTransport`] with a [`PlainCodec`] and then immediately
595    /// performing a [`server_handshake`], returning the updated [`FramedTransport`] on success.
596    ///
597    /// [`client_handshake`]: FramedTransport::client_handshake
598    #[inline]
599    pub async fn from_server_handshake(transport: T) -> io::Result<Self> {
600        let mut transport = Self::plain(transport);
601        transport.server_handshake().await?;
602        Ok(transport)
603    }
604
605    /// Perform the server-side of a handshake. See [`handshake`] for more details.
606    ///
607    /// [`handshake`]: FramedTransport::handshake
608    pub async fn server_handshake(&mut self) -> io::Result<()> {
609        self.handshake(Handshake::server()).await
610    }
611
612    /// Performs a handshake in order to establish a new codec to use between this transport and
613    /// the other side. The parameter `handshake` defines how the transport will handle the
614    /// handshake with `Client` being used to pick the compression and encryption used while
615    /// `Server` defines what the choices are for compression and encryption.
616    ///
617    /// This will reset the framed transport's codec to [`PlainCodec`] in order to communicate
618    /// which compression and encryption to use. Upon selecting an encryption type, a shared secret
619    /// key will be derived on both sides and used to establish the [`EncryptionCodec`], which in
620    /// combination with the [`CompressionCodec`] (if any) will replace this transport's codec.
621    ///
622    /// ### Client
623    ///
624    /// 1. Wait for options from server
625    /// 2. Send to server a compression and encryption choice
626    /// 3. Configure framed transport using selected choices
627    /// 4. Invoke on_handshake function
628    ///
629    /// ### Server
630    ///
631    /// 1. Send options to client
632    /// 2. Receive choices from client
633    /// 3. Configure framed transport using client's choices
634    /// 4. Invoke on_handshake function
635    ///
636    /// ### Failure
637    ///
638    /// The handshake will fail in several cases:
639    ///
640    /// * If any frame during the handshake fails to be serialized
641    /// * If any unexpected frame is received during the handshake
642    /// * If using encryption and unable to derive a shared secret key
643    ///
644    /// If a failure happens, the codec will be reset to what it was prior to the handshake
645    /// request, and all internal buffers will be cleared to avoid corruption.
646    ///
647    pub async fn handshake(&mut self, handshake: Handshake) -> io::Result<()> {
648        // Place transport in plain text communication mode for start of handshake, and clear any
649        // data that is lingering within internal buffers
650        //
651        // NOTE: We grab the old codec in case we encounter an error and need to reset it
652        let old_codec = std::mem::replace(&mut self.codec, Box::new(PlainCodec::new()));
653        self.clear();
654
655        // Swap out our backup so we don't mutate it from synchronization efforts
656        let backup = std::mem::take(&mut self.backup);
657
658        // Transform the transport's codec to abide by the choice. In the case of an error, we
659        // reset the codec back to what it was prior to attempting the handshake and clear the
660        // internal buffers as they may be corrupt.
661        match self.handshake_impl(handshake).await {
662            Ok(codec) => {
663                self.set_codec(codec);
664                self.backup = backup;
665                Ok(())
666            }
667            Err(x) => {
668                self.set_codec(old_codec);
669                self.clear();
670                self.backup = backup;
671                Err(x)
672            }
673        }
674    }
675
676    async fn handshake_impl(&mut self, handshake: Handshake) -> io::Result<BoxedCodec> {
677        #[derive(Debug, Serialize, Deserialize)]
678        struct Choice {
679            compression_level: Option<CompressionLevel>,
680            compression_type: Option<CompressionType>,
681            encryption_type: Option<EncryptionType>,
682        }
683
684        #[derive(Debug, Serialize, Deserialize)]
685        struct Options {
686            compression_types: Vec<CompressionType>,
687            encryption_types: Vec<EncryptionType>,
688        }
689
690        // Define a label to distinguish log output for client and server
691        let log_label = if handshake.is_client() {
692            "Handshake | Client"
693        } else {
694            "Handshake | Server"
695        };
696
697        // Determine compression and encryption to apply to framed transport
698        let choice = match handshake {
699            Handshake::Client {
700                preferred_compression_type,
701                preferred_compression_level,
702                preferred_encryption_type,
703            } => {
704                // Receive options from the server and pick one
705                debug!("[{log_label}] Waiting on options");
706                let options = self.read_frame_as::<Options>().await?.ok_or_else(|| {
707                    io::Error::new(
708                        io::ErrorKind::UnexpectedEof,
709                        "Transport closed early while waiting for options",
710                    )
711                })?;
712
713                // Choose a compression and encryption option from the options
714                debug!("[{log_label}] Selecting from options: {options:?}");
715                let choice = Choice {
716                    // Use preferred compression if available, otherwise default to no compression
717                    // to avoid choosing something poor
718                    compression_type: preferred_compression_type
719                        .filter(|ty| options.compression_types.contains(ty)),
720
721                    // Use preferred compression level, otherwise allowing the server to pick
722                    compression_level: preferred_compression_level,
723
724                    // Use preferred encryption, otherwise pick first non-unknown encryption type
725                    // that is available instead
726                    encryption_type: preferred_encryption_type
727                        .filter(|ty| options.encryption_types.contains(ty))
728                        .or_else(|| {
729                            options
730                                .encryption_types
731                                .iter()
732                                .find(|ty| !ty.is_unknown())
733                                .copied()
734                        }),
735                };
736
737                // Report back to the server the choice
738                debug!("[{log_label}] Reporting choice: {choice:?}");
739                self.write_frame_for(&choice).await?;
740
741                choice
742            }
743            Handshake::Server {
744                compression_types,
745                encryption_types,
746            } => {
747                let options = Options {
748                    compression_types: compression_types.to_vec(),
749                    encryption_types: encryption_types.to_vec(),
750                };
751
752                // Send options to the client
753                debug!("[{log_label}] Sending options: {options:?}");
754                self.write_frame_for(&options).await?;
755
756                // Get client's response with selected compression and encryption
757                debug!("[{log_label}] Waiting on choice");
758                self.read_frame_as::<Choice>().await?.ok_or_else(|| {
759                    io::Error::new(
760                        io::ErrorKind::UnexpectedEof,
761                        "Transport closed early while waiting for choice",
762                    )
763                })?
764            }
765        };
766
767        debug!("[{log_label}] Building compression & encryption codecs based on {choice:?}");
768        let compression_level = choice.compression_level.unwrap_or_default();
769
770        // Acquire a codec for the compression type
771        let compression_codec = choice
772            .compression_type
773            .map(|ty| ty.new_codec(compression_level))
774            .transpose()?;
775
776        // In the case that we are using encryption, we derive a shared secret key to use with the
777        // encryption type
778        let encryption_codec = match choice.encryption_type {
779            // Fail early if we got an unknown encryption type
780            Some(EncryptionType::Unknown) => {
781                return Err(io::Error::new(
782                    io::ErrorKind::InvalidInput,
783                    "Unknown compression type",
784                ))
785            }
786            Some(ty) => {
787                let key = self.exchange_keys_impl(log_label).await?;
788                Some(ty.new_codec(key.unprotected_as_bytes())?)
789            }
790            None => None,
791        };
792
793        // Bundle our compression and encryption codecs into a single, chained codec
794        trace!("[{log_label}] Bundling codecs");
795        let codec: BoxedCodec = match (compression_codec, encryption_codec) {
796            // If we have both encryption and compression, do the encryption first and then
797            // compress in order to get smallest result
798            (Some(c), Some(e)) => Box::new(ChainCodec::new(e, c)),
799
800            // If we just have compression, pass along the compression codec
801            (Some(c), None) => Box::new(c),
802
803            // If we just have encryption, pass along the encryption codec
804            (None, Some(e)) => Box::new(e),
805
806            // If we have neither compression nor encryption, use a plaintext codec
807            (None, None) => Box::new(PlainCodec::new()),
808        };
809
810        Ok(codec)
811    }
812
813    /// Places the transport into key-exchange mode where it attempts to derive a shared secret key
814    /// with the other transport.
815    pub async fn exchange_keys(&mut self) -> io::Result<SecretKey32> {
816        self.exchange_keys_impl("").await
817    }
818
819    async fn exchange_keys_impl(&mut self, label: &str) -> io::Result<SecretKey32> {
820        let log_label = if label.is_empty() {
821            String::new()
822        } else {
823            format!("[{label}] ")
824        };
825
826        #[derive(Serialize, Deserialize)]
827        struct KeyExchangeData {
828            /// Bytes of the public key
829            #[serde(with = "serde_bytes")]
830            public_key: PublicKeyBytes,
831
832            /// Randomly generated salt
833            #[serde(with = "serde_bytes")]
834            salt: Salt,
835        }
836
837        debug!("{log_label}Exchanging public key and salt");
838        let exchange = KeyExchange::default();
839        self.write_frame_for(&KeyExchangeData {
840            public_key: exchange.pk_bytes(),
841            salt: *exchange.salt(),
842        })
843        .await?;
844
845        // TODO: This key only works because it happens to be 32 bytes and our encryption
846        //       also wants a 32-byte key. Once we introduce new encryption algorithms that
847        //       are not using 32-byte keys, the key exchange will need to support deriving
848        //       other length keys.
849        trace!("{log_label}Waiting on public key and salt from other side");
850        let data = self
851            .read_frame_as::<KeyExchangeData>()
852            .await?
853            .ok_or_else(|| {
854                io::Error::new(
855                    io::ErrorKind::UnexpectedEof,
856                    "Transport closed early while waiting for key data",
857                )
858            })?;
859
860        trace!("{log_label}Deriving shared secret key");
861        let key = exchange.derive_shared_secret(data.public_key, data.salt)?;
862        Ok(key)
863    }
864}
865
866#[async_trait]
867impl<T> Reconnectable for FramedTransport<T>
868where
869    T: Transport,
870{
871    async fn reconnect(&mut self) -> io::Result<()> {
872        Reconnectable::reconnect(&mut self.inner).await
873    }
874}
875
876impl FramedTransport<InmemoryTransport> {
877    /// Produces a pair of inmemory transports that are connected to each other using a
878    /// [`PlainCodec`].
879    ///
880    /// Sets the buffer for message passing for each underlying transport to the given buffer size.
881    pub fn pair(
882        buffer: usize,
883    ) -> (
884        FramedTransport<InmemoryTransport>,
885        FramedTransport<InmemoryTransport>,
886    ) {
887        let (a, b) = InmemoryTransport::pair(buffer);
888        let a = FramedTransport::new(a, Box::new(PlainCodec::new()));
889        let b = FramedTransport::new(b, Box::new(PlainCodec::new()));
890        (a, b)
891    }
892
893    /// Links the underlying transports together using [`InmemoryTransport::link`].
894    pub fn link(&mut self, other: &mut Self, buffer: usize) {
895        self.inner.link(&mut other.inner, buffer)
896    }
897}
898
#[cfg(test)]
impl FramedTransport<InmemoryTransport> {
    /// Generates a connected test pair using [`PlainCodec`]s, forwarding `buffer` as the
    /// message-passing buffer size of each underlying [`InmemoryTransport`].
    ///
    /// Equivalent to [`FramedTransport::pair`].
    pub fn test_pair(
        buffer: usize,
    ) -> (
        FramedTransport<InmemoryTransport>,
        FramedTransport<InmemoryTransport>,
    ) {
        Self::pair(buffer)
    }
}
911
912#[cfg(test)]
913mod tests {
914    use bytes::BufMut;
915    use test_log::test;
916
917    use super::*;
918    use crate::common::TestTransport;
919
920    /// Codec that always succeeds without altering the frame
921    #[derive(Clone, Debug, PartialEq, Eq)]
922    struct OkCodec;
923
924    impl Codec for OkCodec {
925        fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
926            Ok(frame)
927        }
928
929        fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
930            Ok(frame)
931        }
932    }
933
934    /// Codec that always fails
935    #[derive(Clone, Debug, PartialEq, Eq)]
936    struct ErrCodec;
937
938    impl Codec for ErrCodec {
939        fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
940            Err(io::Error::from(io::ErrorKind::Other))
941        }
942
943        fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
944            Err(io::Error::from(io::ErrorKind::Other))
945        }
946    }
947
948    // Hardcoded custom codec so we can verify it works differently than plain codec
949    #[derive(Clone)]
950    struct CustomCodec;
951
952    impl Codec for CustomCodec {
953        fn encode<'a>(&mut self, _: Frame<'a>) -> io::Result<Frame<'a>> {
954            Ok(Frame::new(b"encode"))
955        }
956
957        fn decode<'a>(&mut self, _: Frame<'a>) -> io::Result<Frame<'a>> {
958            Ok(Frame::new(b"decode"))
959        }
960    }
961
962    type SimulateTryReadFn = Box<dyn Fn(&mut [u8]) -> io::Result<usize> + Send + Sync>;
963
964    /// Simulate calls to try_read by feeding back `data` in `step` increments, triggering a block
965    /// if `block_on` returns true where `block_on` is provided a counter value that is incremented
966    /// every time the simulated `try_read` function is called
967    ///
968    /// NOTE: This will inject the frame len in front of the provided data to properly simulate
969    ///       receiving a frame of data
970    fn simulate_try_read(
971        frames: Vec<Frame>,
972        step: usize,
973        block_on: impl Fn(usize) -> bool + Send + Sync + 'static,
974    ) -> SimulateTryReadFn {
975        use std::sync::atomic::{AtomicUsize, Ordering};
976
977        // Stuff all of our frames into a single byte collection
978        let data = {
979            let mut buf = BytesMut::new();
980
981            for frame in frames {
982                frame.write(&mut buf);
983            }
984
985            buf.to_vec()
986        };
987
988        let idx = AtomicUsize::new(0);
989        let cnt = AtomicUsize::new(0);
990
991        Box::new(move |buf| {
992            if block_on(cnt.fetch_add(1, Ordering::Relaxed)) {
993                return Err(io::Error::from(io::ErrorKind::WouldBlock));
994            }
995
996            let start = idx.fetch_add(step, Ordering::Relaxed);
997            let end = start + step;
998            let end = if end > data.len() { data.len() } else { end };
999            let len = if start > end { 0 } else { end - start };
1000
1001            buf[..len].copy_from_slice(&data[start..end]);
1002            Ok(len)
1003        })
1004    }
1005
1006    #[test]
1007    fn try_read_frame_should_return_would_block_if_fails_to_read_frame_before_blocking() {
1008        // Should fail if immediately blocks
1009        let mut transport = FramedTransport::new(
1010            TestTransport {
1011                f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))),
1012                f_ready: Box::new(|_| Ok(Ready::READABLE)),
1013                ..Default::default()
1014            },
1015            Box::new(OkCodec),
1016        );
1017        assert_eq!(
1018            transport.try_read_frame().unwrap_err().kind(),
1019            io::ErrorKind::WouldBlock
1020        );
1021
1022        // Should fail if not read enough bytes before blocking
1023        let mut transport = FramedTransport::new(
1024            TestTransport {
1025                f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |cnt| cnt == 1),
1026                f_ready: Box::new(|_| Ok(Ready::READABLE)),
1027                ..Default::default()
1028            },
1029            Box::new(OkCodec),
1030        );
1031        assert_eq!(
1032            transport.try_read_frame().unwrap_err().kind(),
1033            io::ErrorKind::WouldBlock
1034        );
1035    }
1036
1037    #[test]
1038    fn try_read_frame_should_return_error_if_encountered_error_with_reading_bytes() {
1039        let mut transport = FramedTransport::new(
1040            TestTransport {
1041                f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1042                f_ready: Box::new(|_| Ok(Ready::READABLE)),
1043                ..Default::default()
1044            },
1045            Box::new(OkCodec),
1046        );
1047        assert_eq!(
1048            transport.try_read_frame().unwrap_err().kind(),
1049            io::ErrorKind::NotConnected
1050        );
1051    }
1052
1053    #[test]
1054    fn try_read_frame_should_return_error_if_encountered_error_during_decode() {
1055        let mut transport = FramedTransport::new(
1056            TestTransport {
1057                f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |_| false),
1058                f_ready: Box::new(|_| Ok(Ready::READABLE)),
1059                ..Default::default()
1060            },
1061            Box::new(ErrCodec),
1062        );
1063        assert_eq!(
1064            transport.try_read_frame().unwrap_err().kind(),
1065            io::ErrorKind::Other
1066        );
1067    }
1068
1069    #[test]
1070    fn try_read_frame_should_return_next_available_frame() {
1071        let data = {
1072            let mut data = BytesMut::new();
1073            Frame::new(b"hello world").write(&mut data);
1074            data.freeze()
1075        };
1076
1077        let mut transport = FramedTransport::new(
1078            TestTransport {
1079                f_try_read: Box::new(move |buf| {
1080                    buf[..data.len()].copy_from_slice(data.as_ref());
1081                    Ok(data.len())
1082                }),
1083                f_ready: Box::new(|_| Ok(Ready::READABLE)),
1084                ..Default::default()
1085            },
1086            Box::new(OkCodec),
1087        );
1088        assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
1089    }
1090
1091    #[test]
1092    fn try_read_frame_should_return_next_available_frame_if_already_in_incoming_buffer() {
1093        // Store two frames in our data to transmit
1094        let data = {
1095            let mut data = BytesMut::new();
1096            Frame::new(b"hello world").write(&mut data);
1097            Frame::new(b"hello again").write(&mut data);
1098            data.freeze()
1099        };
1100
1101        // Configure transport to return both frames in single read such that we have another
1102        // complete frame to parse (in the case that an underlying try_read would block, but we had
1103        // data available before that)
1104        let mut transport = FramedTransport::new(
1105            TestTransport {
1106                f_try_read: Box::new(move |buf| {
1107                    static mut CNT: usize = 0;
1108                    unsafe {
1109                        CNT += 1;
1110                        if CNT == 2 {
1111                            Err(io::Error::from(io::ErrorKind::WouldBlock))
1112                        } else {
1113                            let n = data.len();
1114                            buf[..data.len()].copy_from_slice(data.as_ref());
1115                            Ok(n)
1116                        }
1117                    }
1118                }),
1119                f_ready: Box::new(|_| Ok(Ready::READABLE)),
1120                ..Default::default()
1121            },
1122            Box::new(OkCodec),
1123        );
1124
1125        // Read first frame
1126        assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
1127
1128        // Read second frame
1129        assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello again");
1130    }
1131
1132    #[test]
1133    fn try_read_frame_should_keep_reading_until_a_frame_is_found() {
1134        const STEP_SIZE: usize = Frame::HEADER_SIZE + 7;
1135
1136        let mut transport = FramedTransport::new(
1137            TestTransport {
1138                f_try_read: simulate_try_read(
1139                    vec![Frame::new(b"hello world"), Frame::new(b"test hello")],
1140                    STEP_SIZE,
1141                    |_| false,
1142                ),
1143                f_ready: Box::new(|_| Ok(Ready::READABLE)),
1144                ..Default::default()
1145            },
1146            Box::new(OkCodec),
1147        );
1148        assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
1149
1150        // Should have leftover bytes from next frame
1151        // where len = 10, "tes"
1152        assert_eq!(
1153            transport.incoming.to_vec(),
1154            [0, 0, 0, 0, 0, 0, 0, 10, b't', b'e', b's']
1155        );
1156    }
1157
1158    #[test]
1159    fn try_write_frame_should_return_would_block_if_fails_to_write_frame_before_blocking() {
1160        let mut transport = FramedTransport::new(
1161            TestTransport {
1162                f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))),
1163                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1164                ..Default::default()
1165            },
1166            Box::new(OkCodec),
1167        );
1168
1169        // First call will only write part of the frame and then return WouldBlock
1170        assert_eq!(
1171            transport
1172                .try_write_frame(b"hello world")
1173                .unwrap_err()
1174                .kind(),
1175            io::ErrorKind::WouldBlock
1176        );
1177    }
1178
1179    #[test]
1180    fn try_write_frame_should_return_error_if_encountered_error_with_writing_bytes() {
1181        let mut transport = FramedTransport::new(
1182            TestTransport {
1183                f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1184                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1185                ..Default::default()
1186            },
1187            Box::new(OkCodec),
1188        );
1189        assert_eq!(
1190            transport
1191                .try_write_frame(b"hello world")
1192                .unwrap_err()
1193                .kind(),
1194            io::ErrorKind::NotConnected
1195        );
1196    }
1197
1198    #[test]
1199    fn try_write_frame_should_return_error_if_encountered_error_during_encode() {
1200        let mut transport = FramedTransport::new(
1201            TestTransport {
1202                f_try_write: Box::new(|buf| Ok(buf.len())),
1203                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1204                ..Default::default()
1205            },
1206            Box::new(ErrCodec),
1207        );
1208        assert_eq!(
1209            transport
1210                .try_write_frame(b"hello world")
1211                .unwrap_err()
1212                .kind(),
1213            io::ErrorKind::Other
1214        );
1215    }
1216
1217    #[test]
1218    fn try_write_frame_should_write_entire_frame_if_possible() {
1219        let (tx, rx) = std::sync::mpsc::sync_channel(1);
1220        let mut transport = FramedTransport::new(
1221            TestTransport {
1222                f_try_write: Box::new(move |buf| {
1223                    let len = buf.len();
1224                    tx.send(buf.to_vec()).unwrap();
1225                    Ok(len)
1226                }),
1227                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1228                ..Default::default()
1229            },
1230            Box::new(OkCodec),
1231        );
1232
1233        transport.try_write_frame(b"hello world").unwrap();
1234
1235        // Transmitted data should be encoded using the framed transport's codec
1236        assert_eq!(
1237            rx.try_recv().unwrap(),
1238            [11u64.to_be_bytes().as_slice(), b"hello world".as_slice()].concat()
1239        );
1240    }
1241
1242    #[test]
1243    fn try_write_frame_should_write_any_prior_queued_bytes_before_writing_next_frame() {
1244        const STEP_SIZE: usize = Frame::HEADER_SIZE + 5;
1245        let (tx, rx) = std::sync::mpsc::sync_channel(10);
1246        let mut transport = FramedTransport::new(
1247            TestTransport {
1248                f_try_write: Box::new(move |buf| {
1249                    static mut CNT: usize = 0;
1250                    unsafe {
1251                        CNT += 1;
1252                        if CNT == 2 {
1253                            Err(io::Error::from(io::ErrorKind::WouldBlock))
1254                        } else {
1255                            let len = std::cmp::min(STEP_SIZE, buf.len());
1256                            tx.send(buf[..len].to_vec()).unwrap();
1257                            Ok(len)
1258                        }
1259                    }
1260                }),
1261                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1262                ..Default::default()
1263            },
1264            Box::new(OkCodec),
1265        );
1266
1267        // First call will only write part of the frame and then return WouldBlock
1268        assert_eq!(
1269            transport
1270                .try_write_frame(b"hello world")
1271                .unwrap_err()
1272                .kind(),
1273            io::ErrorKind::WouldBlock
1274        );
1275
1276        // Transmitted data should be encoded using the framed transport's codec
1277        assert_eq!(
1278            rx.try_recv().unwrap(),
1279            [11u64.to_be_bytes().as_slice(), b"hello".as_slice()].concat()
1280        );
1281        assert_eq!(
1282            rx.try_recv().unwrap_err(),
1283            std::sync::mpsc::TryRecvError::Empty
1284        );
1285
1286        // Next call will keep writing successfully until done
1287        transport.try_write_frame(b"test").unwrap();
1288        assert_eq!(
1289            rx.try_recv().unwrap(),
1290            [b' ', b'w', b'o', b'r', b'l', b'd', 0, 0, 0, 0, 0, 0, 0]
1291        );
1292        assert_eq!(rx.try_recv().unwrap(), [4, b't', b'e', b's', b't']);
1293        assert_eq!(
1294            rx.try_recv().unwrap_err(),
1295            std::sync::mpsc::TryRecvError::Empty
1296        );
1297    }
1298
1299    #[test]
1300    fn try_flush_should_return_error_if_try_write_fails() {
1301        let mut transport = FramedTransport::new(
1302            TestTransport {
1303                f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1304                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1305                ..Default::default()
1306            },
1307            Box::new(OkCodec),
1308        );
1309
1310        // Set our outgoing buffer to flush
1311        transport.outgoing.put_slice(b"hello world");
1312
1313        // Perform flush and verify error happens
1314        assert_eq!(
1315            transport.try_flush().unwrap_err().kind(),
1316            io::ErrorKind::NotConnected
1317        );
1318    }
1319
1320    #[test]
1321    fn try_flush_should_return_error_if_try_write_returns_0_bytes_written() {
1322        let mut transport = FramedTransport::new(
1323            TestTransport {
1324                f_try_write: Box::new(|_| Ok(0)),
1325                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1326                ..Default::default()
1327            },
1328            Box::new(OkCodec),
1329        );
1330
1331        // Set our outgoing buffer to flush
1332        transport.outgoing.put_slice(b"hello world");
1333
1334        // Perform flush and verify error happens
1335        assert_eq!(
1336            transport.try_flush().unwrap_err().kind(),
1337            io::ErrorKind::WriteZero
1338        );
1339    }
1340
1341    #[test]
1342    fn try_flush_should_be_noop_if_nothing_to_flush() {
1343        let mut transport = FramedTransport::new(
1344            TestTransport {
1345                f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1346                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1347                ..Default::default()
1348            },
1349            Box::new(OkCodec),
1350        );
1351
1352        // Perform flush and verify nothing happens
1353        transport.try_flush().unwrap();
1354    }
1355
1356    #[test]
1357    fn try_flush_should_continually_call_try_write_until_outgoing_buffer_is_empty() {
1358        const STEP_SIZE: usize = 5;
1359        let (tx, rx) = std::sync::mpsc::sync_channel(10);
1360        let mut transport = FramedTransport::new(
1361            TestTransport {
1362                f_try_write: Box::new(move |buf| {
1363                    let len = std::cmp::min(STEP_SIZE, buf.len());
1364                    tx.send(buf[..len].to_vec()).unwrap();
1365                    Ok(len)
1366                }),
1367                f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1368                ..Default::default()
1369            },
1370            Box::new(OkCodec),
1371        );
1372
1373        // Set our outgoing buffer to flush
1374        transport.outgoing.put_slice(b"hello world");
1375
1376        // Perform flush
1377        transport.try_flush().unwrap();
1378
1379        // Verify outgoing data flushed with N calls to try_write
1380        assert_eq!(rx.try_recv().unwrap(), b"hello".as_slice());
1381        assert_eq!(rx.try_recv().unwrap(), b" worl".as_slice());
1382        assert_eq!(rx.try_recv().unwrap(), b"d".as_slice());
1383        assert_eq!(
1384            rx.try_recv().unwrap_err(),
1385            std::sync::mpsc::TryRecvError::Empty
1386        );
1387    }
1388
1389    #[inline]
1390    async fn test_synchronize_stats(
1391        transport: &mut FramedTransport<InmemoryTransport>,
1392        sent_cnt: u64,
1393        received_cnt: u64,
1394        available_cnt: u64,
1395        expected_sent_cnt: u64,
1396        expected_received_cnt: u64,
1397        expected_available_cnt: u64,
1398    ) {
1399        // From the other side, claim that we have received 2 frames
1400        // (sent, received, available)
1401        transport
1402            .write_frame_for(&(sent_cnt, received_cnt, available_cnt))
1403            .await
1404            .unwrap();
1405
1406        // Receive stats from the other side
1407        let (sent, received, available) = transport
1408            .read_frame_as::<(u64, u64, u64)>()
1409            .await
1410            .unwrap()
1411            .unwrap();
1412        assert_eq!(sent, expected_sent_cnt, "Wrong sent cnt");
1413        assert_eq!(received, expected_received_cnt, "Wrong received cnt");
1414        assert_eq!(available, expected_available_cnt, "Wrong available cnt");
1415    }
1416
1417    #[test(tokio::test)]
1418    async fn synchronize_should_resend_no_frames_if_other_side_claims_it_has_more_than_us() {
1419        let (mut t1, mut t2) = FramedTransport::pair(100);
1420
1421        // Configure the backup such that we have sent one frame
1422        t2.backup.push_frame(Frame::new(b"hello world"));
1423        t2.backup.increment_sent_cnt();
1424
1425        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1426        // send a frame to indicate when finished so we can know when synchronization is done
1427        // during our test
1428        let _task = tokio::spawn(async move {
1429            t2.synchronize().await.unwrap();
1430            t2.write_frame(Frame::new(b"done")).await.unwrap();
1431            t2
1432        });
1433
1434        // fake     (sent, received, available) = 0, 2, 0
1435        // expected (sent, received, available) = 1, 0, 1
1436        test_synchronize_stats(&mut t1, 0, 2, 0, 1, 0, 1).await;
1437
1438        // Should not receive anything before our done indicator
1439        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1440    }
1441
1442    #[test(tokio::test)]
1443    async fn synchronize_should_resend_no_frames_if_none_missing_on_other_side() {
1444        let (mut t1, mut t2) = FramedTransport::pair(100);
1445
1446        // Configure the backup such that we have sent one frame
1447        t2.backup.push_frame(Frame::new(b"hello world"));
1448        t2.backup.increment_sent_cnt();
1449
1450        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1451        // send a frame to indicate when finished so we can know when synchronization is done
1452        // during our test
1453        let _task = tokio::spawn(async move {
1454            t2.synchronize().await.unwrap();
1455            t2.write_frame(Frame::new(b"done")).await.unwrap();
1456            t2
1457        });
1458
1459        // fake     (sent, received, available) = 0, 1, 0
1460        // expected (sent, received, available) = 1, 0, 1
1461        test_synchronize_stats(&mut t1, 0, 1, 0, 1, 0, 1).await;
1462
1463        // Should not receive anything before our done indicator
1464        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1465    }
1466
1467    #[test(tokio::test)]
1468    async fn synchronize_should_resend_some_frames_if_some_missing_on_other_side() {
1469        let (mut t1, mut t2) = FramedTransport::pair(100);
1470
1471        // Configure the backup such that we have sent two frames
1472        t2.backup.push_frame(Frame::new(b"hello"));
1473        t2.backup.push_frame(Frame::new(b"world"));
1474        t2.backup.increment_sent_cnt();
1475        t2.backup.increment_sent_cnt();
1476
1477        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1478        // send a frame to indicate when finished so we can know when synchronization is done
1479        // during our test
1480        let _task = tokio::spawn(async move {
1481            t2.synchronize().await.unwrap();
1482            t2.write_frame(Frame::new(b"done")).await.unwrap();
1483            t2
1484        });
1485
1486        // fake     (sent, received, available) = 0, 1, 0
1487        // expected (sent, received, available) = 2, 0, 2
1488        test_synchronize_stats(&mut t1, 0, 1, 0, 2, 0, 2).await;
1489
1490        // Recieve both frames and then the done indicator
1491        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
1492        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1493    }
1494
1495    #[test(tokio::test)]
1496    async fn synchronize_should_resend_all_frames_if_all_missing_on_other_side() {
1497        let (mut t1, mut t2) = FramedTransport::pair(100);
1498
1499        // Configure the backup such that we have sent two frames
1500        t2.backup.push_frame(Frame::new(b"hello"));
1501        t2.backup.push_frame(Frame::new(b"world"));
1502        t2.backup.increment_sent_cnt();
1503        t2.backup.increment_sent_cnt();
1504
1505        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1506        // send a frame to indicate when finished so we can know when synchronization is done
1507        // during our test
1508        let _task = tokio::spawn(async move {
1509            t2.synchronize().await.unwrap();
1510            t2.write_frame(Frame::new(b"done")).await.unwrap();
1511            t2
1512        });
1513
1514        // fake     (sent, received, available) = 0, 0, 0
1515        // expected (sent, received, available) = 2, 0, 2
1516        test_synchronize_stats(&mut t1, 0, 0, 0, 2, 0, 2).await;
1517
1518        // Recieve both frames and then the done indicator
1519        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
1520        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
1521        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1522    }
1523
1524    #[test(tokio::test)]
1525    async fn synchronize_should_resend_available_frames_if_more_than_available_missing_on_other_side(
1526    ) {
1527        let (mut t1, mut t2) = FramedTransport::pair(100);
1528
1529        // Configure the backup such that we have sent two frames, and believe that we have
1530        // sent 3 in total, a situation that happens once we reach the peak possible size of
1531        // old frames to store
1532        t2.backup.push_frame(Frame::new(b"hello"));
1533        t2.backup.push_frame(Frame::new(b"world"));
1534        t2.backup.increment_sent_cnt();
1535        t2.backup.increment_sent_cnt();
1536        t2.backup.increment_sent_cnt();
1537
1538        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1539        // send a frame to indicate when finished so we can know when synchronization is done
1540        // during our test
1541        let _task = tokio::spawn(async move {
1542            t2.synchronize().await.unwrap();
1543            t2.write_frame(Frame::new(b"done")).await.unwrap();
1544            t2
1545        });
1546
1547        // fake     (sent, received, available) = 0, 0, 0
1548        // expected (sent, received, available) = 3, 0, 2
1549        test_synchronize_stats(&mut t1, 0, 0, 0, 3, 0, 2).await;
1550
1551        // Recieve both frames and then the done indicator
1552        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
1553        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
1554        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1555    }
1556
1557    #[test(tokio::test)]
1558    async fn synchronize_should_receive_no_frames_if_other_side_claims_it_has_more_than_us() {
1559        let (mut t1, mut t2) = FramedTransport::pair(100);
1560
1561        // Mark other side as having received a frame
1562        t2.backup.increment_received_cnt();
1563
1564        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1565        // send a frame to indicate when finished so we can know when synchronization is done
1566        // during our test
1567        let _task = tokio::spawn(async move {
1568            t2.synchronize().await.unwrap();
1569            t2.write_frame(Frame::new(b"done")).await.unwrap();
1570            t2
1571        });
1572
1573        // fake     (sent, received, available) = 0, 0, 0
1574        // expected (sent, received, available) = 0, 1, 0
1575        test_synchronize_stats(&mut t1, 0, 0, 0, 0, 1, 0).await;
1576
1577        // Recieve the done indicator
1578        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1579    }
1580
1581    #[test(tokio::test)]
1582    async fn synchronize_should_receive_no_frames_if_none_missing_from_other_side() {
1583        let (mut t1, mut t2) = FramedTransport::pair(100);
1584
1585        // Mark other side as having received a frame
1586        t2.backup.increment_received_cnt();
1587
1588        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1589        // send a frame to indicate when finished so we can know when synchronization is done
1590        // during our test
1591        let _task = tokio::spawn(async move {
1592            t2.synchronize().await.unwrap();
1593            t2.write_frame(Frame::new(b"done")).await.unwrap();
1594            t2
1595        });
1596
1597        // fake     (sent, received, available) = 1, 0, 1
1598        // expected (sent, received, available) = 0, 1, 0
1599        test_synchronize_stats(&mut t1, 1, 0, 1, 0, 1, 0).await;
1600
1601        // Recieve the done indicator
1602        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1603    }
1604
1605    #[test(tokio::test)]
1606    async fn synchronize_should_receive_some_frames_if_some_missing_from_other_side() {
1607        let (mut t1, mut t2) = FramedTransport::pair(100);
1608
1609        // Mark other side as having received a frame
1610        t2.backup.increment_received_cnt();
1611
1612        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1613        // send a frame to indicate when finished so we can know when synchronization is done
1614        // during our test
1615        let task = tokio::spawn(async move {
1616            t2.synchronize().await.unwrap();
1617            t2.write_frame(Frame::new(b"done")).await.unwrap();
1618            t2
1619        });
1620
1621        // fake     (sent, received, available) = 2, 0, 2
1622        // expected (sent, received, available) = 0, 1, 0
1623        test_synchronize_stats(&mut t1, 2, 0, 2, 0, 1, 0).await;
1624
1625        // Send a frame to fill the gap
1626        t1.write_frame(Frame::new(b"hello")).await.unwrap();
1627
1628        // Recieve the done indicator
1629        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1630
1631        // Drop the transport such that the other side will get a definite termination
1632        drop(t1);
1633
1634        // Verify that the frame was captured on the other side
1635        let mut t2 = task.await.unwrap();
1636        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1637        assert_eq!(t2.read_frame().await.unwrap(), None);
1638    }
1639
1640    #[test(tokio::test)]
1641    async fn synchronize_should_receive_all_frames_if_all_missing_from_other_side() {
1642        let (mut t1, mut t2) = FramedTransport::pair(100);
1643
1644        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1645        // send a frame to indicate when finished so we can know when synchronization is done
1646        // during our test
1647        let task = tokio::spawn(async move {
1648            t2.synchronize().await.unwrap();
1649            t2.write_frame(Frame::new(b"done")).await.unwrap();
1650            t2
1651        });
1652
1653        // fake     (sent, received, available) = 2, 0, 2
1654        // expected (sent, received, available) = 0, 0, 0
1655        test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await;
1656
1657        // Send frames to fill the gap
1658        t1.write_frame(Frame::new(b"hello")).await.unwrap();
1659        t1.write_frame(Frame::new(b"world")).await.unwrap();
1660
1661        // Recieve the done indicator
1662        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1663
1664        // Drop the transport such that the other side will get a definite termination
1665        drop(t1);
1666
1667        // Verify that the frame was captured on the other side
1668        let mut t2 = task.await.unwrap();
1669        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1670        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world");
1671        assert_eq!(t2.read_frame().await.unwrap(), None);
1672    }
1673
1674    #[test(tokio::test)]
1675    async fn synchronize_should_receive_all_frames_if_more_than_all_missing_from_other_side() {
1676        let (mut t1, mut t2) = FramedTransport::pair(100);
1677
1678        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1679        // send a frame to indicate when finished so we can know when synchronization is done
1680        // during our test
1681        let task = tokio::spawn(async move {
1682            t2.synchronize().await.unwrap();
1683            t2.write_frame(Frame::new(b"done")).await.unwrap();
1684            t2
1685        });
1686
1687        // fake     (sent, received, available) = 3, 0, 2
1688        // expected (sent, received, available) = 0, 0, 0
1689        test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await;
1690
1691        // Send frames to fill the gap
1692        t1.write_frame(Frame::new(b"hello")).await.unwrap();
1693        t1.write_frame(Frame::new(b"world")).await.unwrap();
1694
1695        // Recieve the done indicator
1696        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1697
1698        // Drop the transport such that the other side will get a definite termination
1699        drop(t1);
1700
1701        // Verify that the frame was captured on the other side
1702        let mut t2 = task.await.unwrap();
1703        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1704        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world");
1705        assert_eq!(t2.read_frame().await.unwrap(), None);
1706    }
1707
1708    #[test(tokio::test)]
1709    async fn synchronize_should_fail_if_connection_terminated_before_receiving_missing_frames() {
1710        let (mut t1, mut t2) = FramedTransport::pair(100);
1711
1712        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1713        // send a frame to indicate when finished so we can know when synchronization is done
1714        // during our test
1715        let task = tokio::spawn(async move {
1716            t2.synchronize().await.unwrap();
1717            t2.write_frame(Frame::new(b"done")).await.unwrap();
1718            t2
1719        });
1720
1721        // fake     (sent, received, available) = 2, 0, 2
1722        // expected (sent, received, available) = 0, 0, 0
1723        test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await;
1724
1725        // Send one frame to fill the gap
1726        t1.write_frame(Frame::new(b"hello")).await.unwrap();
1727
1728        // Drop the transport to cause a failure
1729        drop(t1);
1730
1731        // Verify that the other side's synchronization failed
1732        task.await.unwrap_err();
1733    }
1734
1735    #[test(tokio::test)]
1736    async fn synchronize_should_fail_if_connection_terminated_while_waiting_for_frame_stats() {
1737        let (t1, mut t2) = FramedTransport::pair(100);
1738
1739        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1740        // send a frame to indicate when finished so we can know when synchronization is done
1741        // during our test
1742        let task = tokio::spawn(async move {
1743            t2.synchronize().await.unwrap();
1744            t2.write_frame(Frame::new(b"done")).await.unwrap();
1745            t2
1746        });
1747
1748        // Drop the transport to cause a failure
1749        drop(t1);
1750
1751        // Verify that the other side's synchronization failed
1752        task.await.unwrap_err();
1753    }
1754
1755    #[test(tokio::test)]
1756    async fn synchronize_should_clear_any_prexisting_incoming_and_outgoing_data() {
1757        let (mut t1, mut t2) = FramedTransport::pair(100);
1758
1759        // Put some frames into the incoming and outgoing of our transport
1760        Frame::new(b"bad incoming").write(&mut t2.incoming);
1761        Frame::new(b"bad outgoing").write(&mut t2.outgoing);
1762
1763        // Configure the backup such that we have sent two frames
1764        t2.backup.push_frame(Frame::new(b"hello"));
1765        t2.backup.push_frame(Frame::new(b"world"));
1766        t2.backup.increment_sent_cnt();
1767        t2.backup.increment_sent_cnt();
1768
1769        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1770        // send a frame to indicate when finished so we can know when synchronization is done
1771        // during our test
1772        let task = tokio::spawn(async move {
1773            t2.synchronize().await.unwrap();
1774            t2.write_frame(Frame::new(b"done")).await.unwrap();
1775            t2
1776        });
1777
1778        // fake     (sent, received, available) = 2, 0, 2
1779        // expected (sent, received, available) = 2, 0, 2
1780        test_synchronize_stats(&mut t1, 2, 0, 2, 2, 0, 2).await;
1781
1782        // Send frames to fill the gap
1783        t1.write_frame(Frame::new(b"one")).await.unwrap();
1784        t1.write_frame(Frame::new(b"two")).await.unwrap();
1785
1786        // Recieve both frames and then the done indicator
1787        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
1788        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
1789        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1790
1791        // Drop the transport such that the other side will get a definite termination
1792        drop(t1);
1793
1794        // Verify that the frame was captured on the other side
1795        let mut t2 = task.await.unwrap();
1796        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"one");
1797        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"two");
1798        assert_eq!(t2.read_frame().await.unwrap(), None);
1799    }
1800
1801    #[test(tokio::test)]
1802    async fn synchronize_should_not_increment_the_sent_frames_or_store_replayed_frames_in_the_backup(
1803    ) {
1804        let (mut t1, mut t2) = FramedTransport::pair(100);
1805
1806        // Configure the backup such that we have sent two frames
1807        t2.backup.push_frame(Frame::new(b"hello"));
1808        t2.backup.push_frame(Frame::new(b"world"));
1809        t2.backup.increment_sent_cnt();
1810        t2.backup.increment_sent_cnt();
1811
1812        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1813        // send a frame to indicate when finished so we can know when synchronization is done
1814        // during our test
1815        let task = tokio::spawn(async move {
1816            t2.synchronize().await.unwrap();
1817
1818            t2.backup.freeze();
1819            t2.write_frame(Frame::new(b"done")).await.unwrap();
1820            t2.backup.unfreeze();
1821
1822            t2
1823        });
1824
1825        // fake     (sent, received, available) = 0, 0, 0
1826        // expected (sent, received, available) = 2, 0, 2
1827        test_synchronize_stats(&mut t1, 0, 0, 0, 2, 0, 2).await;
1828
1829        // Recieve both frames and then the done indicator
1830        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
1831        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
1832        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1833
1834        // Drop the transport such that the other side will get a definite termination
1835        drop(t1);
1836
1837        // Verify that the backup on the other side was unaltered by the frames being sent
1838        let t2 = task.await.unwrap();
1839        assert_eq!(t2.backup.sent_cnt(), 2, "Wrong sent cnt");
1840        assert_eq!(t2.backup.received_cnt(), 0, "Wrong received cnt");
1841        assert_eq!(t2.backup.frame_cnt(), 2, "Wrong frame cnt");
1842    }
1843
1844    #[test(tokio::test)]
1845    async fn synchronize_should_update_the_backup_received_cnt_to_match_other_side_sent() {
1846        let (mut t1, mut t2) = FramedTransport::pair(100);
1847
1848        // Spawn a separate task to do synchronization simulation so we don't deadlock, and also
1849        // send a frame to indicate when finished so we can know when synchronization is done
1850        // during our test
1851        let task = tokio::spawn(async move {
1852            t2.synchronize().await.unwrap();
1853
1854            t2.backup.freeze();
1855            t2.write_frame(Frame::new(b"done")).await.unwrap();
1856            t2.backup.unfreeze();
1857
1858            t2
1859        });
1860
1861        // fake     (sent, received, available) = 2, 0, 1
1862        // expected (sent, received, available) = 0, 0, 0
1863        test_synchronize_stats(&mut t1, 2, 0, 1, 0, 0, 0).await;
1864
1865        // Send frames to fill the gap
1866        t1.write_frame(Frame::new(b"hello")).await.unwrap();
1867
1868        // Recieve both frames and then the done indicator
1869        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1870
1871        // Drop the transport such that the other side will get a definite termination
1872        drop(t1);
1873
1874        // Verify that the backup on the other side updated based on sent count and not available
1875        let t2 = task.await.unwrap();
1876        assert_eq!(t2.backup.sent_cnt(), 0, "Wrong sent cnt");
1877        assert_eq!(t2.backup.received_cnt(), 2, "Wrong received cnt");
1878        assert_eq!(t2.backup.frame_cnt(), 0, "Wrong frame cnt");
1879    }
1880
1881    #[test(tokio::test)]
1882    async fn synchronize_should_work_even_if_codec_changes_between_attempts() {
1883        let (mut t1, _t1_other) = FramedTransport::pair(100);
1884        let (mut t2, _t2_other) = FramedTransport::pair(100);
1885
1886        // Send some frames from each side
1887        t1.write_frame(Frame::new(b"hello")).await.unwrap();
1888        t1.write_frame(Frame::new(b"world")).await.unwrap();
1889        t2.write_frame(Frame::new(b"foo")).await.unwrap();
1890        t2.write_frame(Frame::new(b"bar")).await.unwrap();
1891
1892        // Drop the other transports, link our real transports together, and change the codec
1893        drop(_t1_other);
1894        drop(_t2_other);
1895        t1.link(&mut t2, 100);
1896        let codec = EncryptionCodec::new_xchacha20poly1305(Default::default());
1897        t1.codec = Box::new(codec.clone());
1898        t2.codec = Box::new(codec);
1899
1900        // Spawn a separate task to do synchronization so we don't deadlock
1901        let task = tokio::spawn(async move {
1902            t2.synchronize().await.unwrap();
1903            t2
1904        });
1905
1906        t1.synchronize().await.unwrap();
1907
1908        // Verify that we get the appropriate frames from both sides
1909        let mut t2 = task.await.unwrap();
1910        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"foo");
1911        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"bar");
1912        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1913        assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world");
1914    }
1915
1916    #[test(tokio::test)]
1917    async fn handshake_should_configure_transports_with_matching_codec() {
1918        let (mut t1, mut t2) = FramedTransport::test_pair(100);
1919
1920        // NOTE: Spawn a separate task for one of our transports so we can communicate without
1921        //       deadlocking
1922        let task = tokio::spawn(async move {
1923            // Wait for handshake to complete
1924            t2.server_handshake().await.unwrap();
1925
1926            // Receive one frame and echo it back
1927            let frame = t2.read_frame().await.unwrap().unwrap();
1928            t2.write_frame(frame).await.unwrap();
1929        });
1930
1931        t1.client_handshake().await.unwrap();
1932
1933        // Verify that the transports can still communicate with one another
1934        t1.write_frame(b"hello world").await.unwrap();
1935        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello world");
1936
1937        // Ensure that the other transport did not error
1938        task.await.unwrap();
1939    }
1940
1941    #[test(tokio::test)]
1942    async fn handshake_failing_should_ensure_existing_codec_remains() {
1943        let (mut t1, t2) = FramedTransport::test_pair(100);
1944
1945        // Set a different codec on our transport so we can verify it doesn't change
1946        t1.set_codec(Box::new(CustomCodec));
1947
1948        // Drop our transport on the other side to cause an immediate failure
1949        drop(t2);
1950
1951        // Ensure we detect the failure on handshake
1952        t1.client_handshake().await.unwrap_err();
1953
1954        // Verify that the codec did not reset to plain text by using the codec
1955        assert_eq!(t1.codec.encode(Frame::new(b"test")).unwrap(), b"encode");
1956        assert_eq!(t1.codec.decode(Frame::new(b"test")).unwrap(), b"decode");
1957    }
1958
1959    #[test(tokio::test)]
1960    async fn handshake_should_clear_any_intermittent_buffer_contents_prior_to_handshake_failing() {
1961        let (mut t1, t2) = FramedTransport::test_pair(100);
1962
1963        // Set a different codec on our transport so we can verify it doesn't change
1964        t1.set_codec(Box::new(CustomCodec));
1965
1966        // Drop our transport on the other side to cause an immediate failure
1967        drop(t2);
1968
1969        // Put some garbage in our buffers
1970        t1.incoming.extend_from_slice(b"garbage in");
1971        t1.outgoing.extend_from_slice(b"garbage out");
1972
1973        // Ensure we detect the failure on handshake
1974        t1.client_handshake().await.unwrap_err();
1975
1976        // Verify that the incoming and outgoing buffers are empty
1977        assert!(t1.incoming.is_empty());
1978        assert!(t1.outgoing.is_empty());
1979    }
1980
1981    #[test(tokio::test)]
1982    async fn handshake_should_clear_any_intermittent_buffer_contents_prior_to_handshake_succeeding()
1983    {
1984        let (mut t1, mut t2) = FramedTransport::test_pair(100);
1985
1986        // NOTE: Spawn a separate task for one of our transports so we can communicate without
1987        //       deadlocking
1988        let task = tokio::spawn(async move {
1989            // Wait for handshake to complete
1990            t2.server_handshake().await.unwrap();
1991
1992            // Receive one frame and echo it back
1993            let frame = t2.read_frame().await.unwrap().unwrap();
1994            t2.write_frame(frame).await.unwrap();
1995        });
1996
1997        // Put some garbage in our buffers
1998        t1.incoming.extend_from_slice(b"garbage in");
1999        t1.outgoing.extend_from_slice(b"garbage out");
2000
2001        t1.client_handshake().await.unwrap();
2002
2003        // Verify that the transports can still communicate with one another
2004        t1.write_frame(b"hello world").await.unwrap();
2005        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello world");
2006
2007        // Ensure that the other transport did not error
2008        task.await.unwrap();
2009
2010        // Verify that the incoming and outgoing buffers are empty
2011        assert!(t1.incoming.is_empty());
2012        assert!(t1.outgoing.is_empty());
2013    }
2014
2015    #[test(tokio::test)]
2016    async fn handshake_for_client_should_fail_if_receives_unexpected_frame_instead_of_options() {
2017        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2018
2019        // NOTE: Spawn a separate task for one of our transports so we can communicate without
2020        //       deadlocking
2021        let task = tokio::spawn(async move {
2022            t2.write_frame(b"not a valid frame for handshake")
2023                .await
2024                .unwrap();
2025        });
2026
2027        // Ensure we detect the failure on handshake
2028        let err = t1.client_handshake().await.unwrap_err();
2029        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2030
2031        // Ensure that the other transport did not error
2032        task.await.unwrap();
2033    }
2034
2035    #[test(tokio::test)]
2036    async fn handshake_for_client_should_fail_unable_to_send_codec_choice_to_other_side() {
2037        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2038
2039        #[derive(Debug, Serialize, Deserialize)]
2040        struct Options {
2041            compression_types: Vec<CompressionType>,
2042            encryption_types: Vec<EncryptionType>,
2043        }
2044
2045        // NOTE: Spawn a separate task for one of our transports so we can communicate without
2046        //       deadlocking
2047        let task = tokio::spawn(async move {
2048            // Send options, and then quit so the client side will fail
2049            t2.write_frame_for(&Options {
2050                compression_types: Vec::new(),
2051                encryption_types: Vec::new(),
2052            })
2053            .await
2054            .unwrap();
2055        });
2056
2057        // Ensure we detect the failure on handshake
2058        let err = t1.client_handshake().await.unwrap_err();
2059        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
2060
2061        // Ensure that the other transport did not error
2062        task.await.unwrap();
2063    }
2064
2065    #[test(tokio::test)]
2066    async fn handshake_for_client_should_fail_if_unable_to_receive_key_exchange_data_from_other_side(
2067    ) {
2068        #[derive(Debug, Serialize, Deserialize)]
2069        struct Options {
2070            compression_types: Vec<CompressionType>,
2071            encryption_types: Vec<EncryptionType>,
2072        }
2073
2074        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2075
2076        // Go ahead and queue up a choice, and then queue up invalid key exchange data
2077        t2.write_frame_for(&Options {
2078            compression_types: CompressionType::known_variants().to_vec(),
2079            encryption_types: EncryptionType::known_variants().to_vec(),
2080        })
2081        .await
2082        .unwrap();
2083
2084        t2.write_frame(b"not valid key exchange data")
2085            .await
2086            .unwrap();
2087
2088        // Ensure we detect the failure on handshake
2089        let err = t1.client_handshake().await.unwrap_err();
2090        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2091    }
2092
2093    #[test(tokio::test)]
2094    async fn handshake_for_server_should_fail_if_receives_unexpected_frame_instead_of_choice() {
2095        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2096
2097        // NOTE: Spawn a separate task for one of our transports so we can communicate without
2098        //       deadlocking
2099        let task = tokio::spawn(async move {
2100            t2.write_frame(b"not a valid frame for handshake")
2101                .await
2102                .unwrap();
2103        });
2104
2105        // Ensure we detect the failure on handshake
2106        let err = t1.server_handshake().await.unwrap_err();
2107        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2108
2109        // Ensure that the other transport did not error
2110        task.await.unwrap();
2111    }
2112
2113    #[test(tokio::test)]
2114    async fn handshake_for_server_should_fail_unable_to_send_codec_options_to_other_side() {
2115        let (mut t1, t2) = FramedTransport::test_pair(100);
2116
2117        // Drop our other transport to ensure that nothing can be sent to it
2118        drop(t2);
2119
2120        // Ensure we detect the failure on handshake
2121        let err = t1.server_handshake().await.unwrap_err();
2122        assert_eq!(err.kind(), io::ErrorKind::WriteZero);
2123    }
2124
2125    #[test(tokio::test)]
2126    async fn handshake_for_server_should_fail_if_selected_codec_choice_uses_an_unknown_compression_type(
2127    ) {
2128        #[derive(Debug, Serialize, Deserialize)]
2129        struct Choice {
2130            compression_level: Option<CompressionLevel>,
2131            compression_type: Option<CompressionType>,
2132            encryption_type: Option<EncryptionType>,
2133        }
2134
2135        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2136
2137        // Go ahead and queue up an improper response
2138        t2.write_frame_for(&Choice {
2139            compression_level: None,
2140            compression_type: Some(CompressionType::Unknown),
2141            encryption_type: None,
2142        })
2143        .await
2144        .unwrap();
2145
2146        // Ensure we detect the failure on handshake
2147        let err = t1.server_handshake().await.unwrap_err();
2148        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
2149    }
2150
2151    #[test(tokio::test)]
2152    async fn handshake_for_server_should_fail_if_selected_codec_choice_uses_an_unknown_encryption_type(
2153    ) {
2154        #[derive(Debug, Serialize, Deserialize)]
2155        struct Choice {
2156            compression_level: Option<CompressionLevel>,
2157            compression_type: Option<CompressionType>,
2158            encryption_type: Option<EncryptionType>,
2159        }
2160
2161        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2162
2163        // Go ahead and queue up an improper response
2164        t2.write_frame_for(&Choice {
2165            compression_level: None,
2166            compression_type: None,
2167            encryption_type: Some(EncryptionType::Unknown),
2168        })
2169        .await
2170        .unwrap();
2171
2172        // Ensure we detect the failure on handshake
2173        let err = t1.server_handshake().await.unwrap_err();
2174        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
2175    }
2176
2177    #[test(tokio::test)]
2178    async fn handshake_for_server_should_fail_if_unable_to_receive_key_exchange_data_from_other_side(
2179    ) {
2180        #[derive(Debug, Serialize, Deserialize)]
2181        struct Choice {
2182            compression_level: Option<CompressionLevel>,
2183            compression_type: Option<CompressionType>,
2184            encryption_type: Option<EncryptionType>,
2185        }
2186
2187        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2188
2189        // Go ahead and queue up a choice, and then queue up invalid key exchange data
2190        t2.write_frame_for(&Choice {
2191            compression_level: None,
2192            compression_type: None,
2193            encryption_type: Some(EncryptionType::XChaCha20Poly1305),
2194        })
2195        .await
2196        .unwrap();
2197
2198        t2.write_frame(b"not valid key exchange data")
2199            .await
2200            .unwrap();
2201
2202        // Ensure we detect the failure on handshake
2203        let err = t1.server_handshake().await.unwrap_err();
2204        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2205    }
2206
2207    #[test(tokio::test)]
2208    async fn exchange_keys_should_fail_if_unable_to_send_exchange_data_to_other_side() {
2209        let (mut t1, t2) = FramedTransport::test_pair(100);
2210
2211        // Drop the other side to ensure that the exchange fails at the beginning
2212        drop(t2);
2213
2214        // Perform key exchange and verify error is as expected
2215        assert_eq!(
2216            t1.exchange_keys().await.unwrap_err().kind(),
2217            io::ErrorKind::WriteZero
2218        );
2219    }
2220
2221    #[test(tokio::test)]
2222    async fn exchange_keys_should_fail_if_received_invalid_exchange_data() {
2223        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2224
2225        // Queue up an invalid exchange response
2226        t2.write_frame(b"some invalid frame").await.unwrap();
2227
2228        // Perform key exchange and verify error is as expected
2229        assert_eq!(
2230            t1.exchange_keys().await.unwrap_err().kind(),
2231            io::ErrorKind::InvalidData
2232        );
2233    }
2234
2235    #[test(tokio::test)]
2236    async fn exchange_keys_should_return_shared_secret_key_if_successful() {
2237        let (mut t1, mut t2) = FramedTransport::test_pair(100);
2238
2239        // Spawn a task to avoid deadlocking
2240        let task = tokio::spawn(async move { t2.exchange_keys().await.unwrap() });
2241
2242        // Perform key exchange
2243        let key = t1.exchange_keys().await.unwrap();
2244
2245        // Validate that the keys on both sides match
2246        assert_eq!(key, task.await.unwrap());
2247    }
2248}