1use std::future::Future;
2use std::time::Duration;
3use std::{fmt, io};
4
5use async_trait::async_trait;
6use bytes::{Buf, BytesMut};
7use log::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use super::{InmemoryTransport, Interest, Ready, Reconnectable, Transport};
12use crate::common::{utils, SecretKey32};
13
14mod backup;
15mod codec;
16mod exchange;
17mod frame;
18mod handshake;
19
20pub use backup::*;
21pub use codec::*;
22pub use exchange::*;
23pub use frame::*;
24pub use handshake::*;
25
/// Size in bytes (8 KiB) of the temporary buffer handed to each `try_read` call on the inner
/// transport; the incoming and outgoing queues start with twice this capacity.
const READ_BUF_SIZE: usize = 8 * 1024;

/// Pause (1ms) taken by the async read/write/flush loops after a `WouldBlock` result before
/// they try again.
const SLEEP_DURATION: Duration = Duration::from_millis(1);
31
/// Wraps an inner transport `T` so that data is exchanged as whole frames, each passed through
/// a [`Codec`] when written and when read.
#[derive(Clone)]
pub struct FramedTransport<T> {
    /// Underlying transport that raw bytes are read from and written to
    inner: T,

    /// Codec used to encode outgoing frames and decode incoming frames
    codec: BoxedCodec,

    /// Bytes read from the inner transport that have not yet been parsed into a frame
    incoming: BytesMut,

    /// Encoded bytes queued to be written to the inner transport
    outgoing: BytesMut,

    /// Sent/received counters and stored frames used by `synchronize` to replay frames
    pub backup: Backup,
}
53
54impl<T> FramedTransport<T> {
55 pub fn new(inner: T, codec: BoxedCodec) -> Self {
56 Self {
57 inner,
58 codec,
59 incoming: BytesMut::with_capacity(READ_BUF_SIZE * 2),
60 outgoing: BytesMut::with_capacity(READ_BUF_SIZE * 2),
61 backup: Backup::new(),
62 }
63 }
64
65 pub fn plain(inner: T) -> Self {
67 Self::new(inner, Box::new(PlainCodec::new()))
68 }
69
70 pub fn set_codec(&mut self, codec: BoxedCodec) {
78 self.codec = codec;
79 }
80
81 pub fn codec(&self) -> &dyn Codec {
88 self.codec.as_ref()
89 }
90
91 pub fn mut_codec(&mut self) -> &mut dyn Codec {
98 self.codec.as_mut()
99 }
100
101 pub fn clear(&mut self) {
103 self.incoming.clear();
104 self.outgoing.clear();
105 }
106
107 pub fn as_inner(&self) -> &T {
109 &self.inner
110 }
111
112 pub fn as_mut_inner(&mut self) -> &mut T {
114 &mut self.inner
115 }
116
117 pub fn into_inner(self) -> T {
119 self.inner
120 }
121}
122
123impl<T> fmt::Debug for FramedTransport<T> {
124 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
125 f.debug_struct("FramedTransport")
126 .field("incoming", &self.incoming)
127 .field("outgoing", &self.outgoing)
128 .field("backup", &self.backup)
129 .finish()
130 }
131}
132
133impl<T: Transport + 'static> FramedTransport<T> {
134 pub fn into_boxed(self) -> FramedTransport<Box<dyn Transport>> {
136 FramedTransport {
137 inner: Box::new(self.inner),
138 codec: self.codec,
139 incoming: self.incoming,
140 outgoing: self.outgoing,
141 backup: self.backup,
142 }
143 }
144}
145
146impl<T: Transport> FramedTransport<T> {
147 pub async fn ready(&self, interest: Interest) -> io::Result<Ready> {
149 let ready = if interest.is_readable() && Frame::available(&self.incoming) {
154 Ready::READABLE
155 } else {
156 Ready::EMPTY
157 };
158
159 if !interest.is_writable() && ready.is_readable() {
162 return Ok(ready);
163 }
164
165 Transport::ready(&self.inner, interest)
168 .await
169 .map(|r| r | ready)
170 }
171
172 pub async fn readable(&self) -> io::Result<()> {
176 let _ = self.ready(Interest::READABLE).await?;
177 Ok(())
178 }
179
180 pub async fn writeable(&self) -> io::Result<()> {
184 let _ = self.ready(Interest::WRITABLE).await?;
185 Ok(())
186 }
187
188 pub async fn readable_or_writeable(&self) -> io::Result<Ready> {
190 self.ready(Interest::READABLE | Interest::WRITABLE).await
191 }
192
193 pub fn try_flush(&mut self) -> io::Result<usize> {
206 let mut bytes_written = 0;
207
208 while !self.outgoing.is_empty() {
210 match self.inner.try_write(self.outgoing.as_ref()) {
211 Ok(0) => return Err(io::Error::from(io::ErrorKind::WriteZero)),
213
214 Ok(n) => {
216 self.outgoing.advance(n);
217 bytes_written += n;
218 }
219
220 Err(x) => return Err(x),
222 }
223 }
224
225 Ok(bytes_written)
226 }
227
228 pub async fn flush(&mut self) -> io::Result<()> {
232 while !self.outgoing.is_empty() {
233 self.writeable().await?;
234 match self.try_flush() {
235 Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
236 tokio::time::sleep(SLEEP_DURATION).await
238 }
239 Err(x) => return Err(x),
240 Ok(_) => return Ok(()),
241 }
242 }
243
244 Ok(())
245 }
246
    /// Attempts to read the next frame without waiting.
    ///
    /// Returns `Ok(Some(frame))` once a complete frame could be parsed from the incoming
    /// queue and decoded by the codec. Returns `Ok(None)` when the inner transport reports
    /// zero bytes read and no bytes are buffered, and an `UnexpectedEof` error when zero
    /// bytes are read while leftover bytes remain buffered. Errors from the inner transport's
    /// `try_read` (including `WouldBlock`) and from decoding are returned as-is.
    pub fn try_read_frame(&mut self) -> io::Result<Option<OwnedFrame>> {
        // Parses one frame from `incoming` if a complete one is present and RETURNS it from
        // the enclosing function; when no frame can be parsed yet, execution falls through
        macro_rules! read_next_frame {
            () => {{
                match Frame::read(&mut self.incoming) {
                    None => (),
                    Some(frame) => {
                        // Only non-empty frames are counted as received in the backup
                        if frame.is_nonempty() {
                            self.backup.increment_received_cnt();
                        }
                        return Ok(Some(self.codec.decode(frame)?.into_owned()));
                    }
                }
            }};
        }

        // Bytes left over from an earlier read may already contain a full frame, so check
        // them before touching the inner transport
        if !self.incoming.is_empty() {
            read_next_frame!();
        }

        // Scratch buffer that each `try_read` call fills
        let mut buf = [0; READ_BUF_SIZE];

        loop {
            match self.inner.try_read(&mut buf) {
                // Zero bytes with nothing buffered: no frame is available
                Ok(0) if self.incoming.is_empty() => return Ok(None),
                // Zero bytes while a partial frame is buffered: the frame can never complete
                Ok(0) => return Err(io::Error::from(io::ErrorKind::UnexpectedEof)),

                Ok(n) => {
                    // Queue the new bytes and check whether they complete a frame
                    self.incoming.extend_from_slice(&buf[..n]);
                    read_next_frame!();
                }

                Err(x) => return Err(x),
            }
        }
    }
307
308 pub fn try_read_frame_as<D: DeserializeOwned>(&mut self) -> io::Result<Option<D>> {
312 match self.try_read_frame() {
313 Ok(Some(frame)) => Ok(Some(utils::deserialize_from_slice(frame.as_item())?)),
314 Ok(None) => Ok(None),
315 Err(x) => Err(x),
316 }
317 }
318
319 pub async fn read_frame(&mut self) -> io::Result<Option<OwnedFrame>> {
325 loop {
326 self.readable().await?;
327
328 match self.try_read_frame() {
329 Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
330 tokio::time::sleep(SLEEP_DURATION).await
332 }
333 x => return x,
334 }
335 }
336 }
337
338 pub async fn read_frame_as<D: DeserializeOwned>(&mut self) -> io::Result<Option<D>> {
342 match self.read_frame().await {
343 Ok(Some(frame)) => Ok(Some(utils::deserialize_from_slice(frame.as_item())?)),
344 Ok(None) => Ok(None),
345 Err(x) => Err(x),
346 }
347 }
348
349 pub fn try_write_frame<'a, F>(&mut self, frame: F) -> io::Result<()>
360 where
361 F: TryInto<Frame<'a>>,
362 F::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
363 {
364 let frame = frame
366 .try_into()
367 .map_err(|x| io::Error::new(io::ErrorKind::InvalidInput, x))?;
368
369 self.codec
371 .encode(frame.as_borrowed())?
372 .write(&mut self.outgoing);
373
374 if frame.is_nonempty() {
376 self.backup.increment_sent_cnt();
378
379 self.backup.push_frame(frame);
382 }
383
384 self.try_flush()?;
386
387 Ok(())
388 }
389
390 pub fn try_write_frame_for<D: Serialize>(&mut self, value: &D) -> io::Result<()> {
394 let data = utils::serialize_to_vec(value)?;
395 self.try_write_frame(data)
396 }
397
398 pub async fn write_frame<'a, F>(&mut self, frame: F) -> io::Result<()>
406 where
407 F: TryInto<Frame<'a>>,
408 F::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
409 {
410 self.writeable().await?;
411
412 match self.try_write_frame(frame) {
413 Err(x) if x.kind() == io::ErrorKind::WouldBlock => loop {
415 self.writeable().await?;
416 match self.try_flush() {
417 Err(x) if x.kind() == io::ErrorKind::WouldBlock => {
418 tokio::time::sleep(SLEEP_DURATION).await
420 }
421 Err(x) => return Err(x),
422 Ok(_) => return Ok(()),
423 }
424 },
425
426 x => x,
428 }
429 }
430
431 pub async fn write_frame_for<D: Serialize>(&mut self, value: &D) -> io::Result<()> {
435 let data = utils::serialize_to_vec(value)?;
436 self.write_frame(data).await
437 }
438
439 pub async fn do_frozen<F, X>(&mut self, mut f: F) -> io::Result<()>
441 where
442 F: FnMut(&mut Self) -> X,
443 X: Future<Output = io::Result<()>>,
444 {
445 let is_frozen = self.backup.is_frozen();
446 self.backup.freeze();
447 let result = f(self).await;
448 self.backup.set_frozen(is_frozen);
449 result
450 }
451
452 pub async fn synchronize(&mut self) -> io::Result<()> {
462 async fn synchronize_impl<T: Transport>(
463 this: &mut FramedTransport<T>,
464 backup: &mut Backup,
465 ) -> io::Result<()> {
466 type Stats = (u64, u64, u64);
467
468 let sent_cnt: u64 = backup.sent_cnt();
470 let received_cnt: u64 = backup.received_cnt();
471 let available_cnt: u64 = backup
472 .frame_cnt()
473 .try_into()
474 .expect("Cannot case usize to u64");
475
476 this.clear();
478
479 trace!(
483 "Stats: sent = {sent_cnt}, received = {received_cnt}, available = {available_cnt}"
484 );
485 this.write_frame_for(&(sent_cnt, received_cnt, available_cnt))
486 .await?;
487 let (other_sent_cnt, other_received_cnt, other_available_cnt) =
488 this.read_frame_as::<Stats>().await?.ok_or_else(|| {
489 io::Error::new(
490 io::ErrorKind::UnexpectedEof,
491 "Transport terminated before getting replay stats",
492 )
493 })?;
494 trace!("Other stats: sent = {other_sent_cnt}, received = {other_received_cnt}, available = {other_available_cnt}");
495
496 let resend_cnt = std::cmp::min(
499 if sent_cnt > other_received_cnt {
500 sent_cnt - other_received_cnt
501 } else {
502 0
503 },
504 available_cnt,
505 );
506
507 let expected_cnt = std::cmp::min(
510 if received_cnt < other_sent_cnt {
511 other_sent_cnt - received_cnt
512 } else {
513 0
514 },
515 other_available_cnt,
516 );
517
518 trace!("Reducing internal replay frames to {resend_cnt}");
520 backup.truncate_front(resend_cnt.try_into().expect("Cannot cast usize to u64"));
521
522 debug!("Sending {resend_cnt} frames");
523 for frame in backup.frames() {
524 this.try_write_frame(frame.as_borrowed())?;
525 }
526 this.flush().await?;
527
528 debug!("Waiting for {expected_cnt} frames");
534 for i in 0..expected_cnt {
535 let frame = this.read_frame().await?.ok_or_else(|| {
536 io::Error::new(
537 io::ErrorKind::UnexpectedEof,
538 format!(
539 "Transport terminated before getting frame {}/{expected_cnt}",
540 i + 1
541 ),
542 )
543 })?;
544
545 this.codec.encode(frame)?.write(&mut this.incoming);
548 }
549
550 if backup.received_cnt() != other_sent_cnt {
553 warn!(
554 "Backup received count ({}) != other sent count ({}), so resetting to match",
555 backup.received_cnt(),
556 other_sent_cnt
557 );
558 backup.set_received_cnt(other_sent_cnt);
559 }
560
561 Ok(())
562 }
563
564 let mut backup = std::mem::take(&mut self.backup);
566
567 let result = synchronize_impl(self, &mut backup).await;
569
570 self.backup = backup;
572
573 result
574 }
575
576 #[inline]
581 pub async fn from_client_handshake(transport: T) -> io::Result<Self> {
582 let mut transport = Self::plain(transport);
583 transport.client_handshake().await?;
584 Ok(transport)
585 }
586
587 pub async fn client_handshake(&mut self) -> io::Result<()> {
591 self.handshake(Handshake::client()).await
592 }
593
594 #[inline]
599 pub async fn from_server_handshake(transport: T) -> io::Result<Self> {
600 let mut transport = Self::plain(transport);
601 transport.server_handshake().await?;
602 Ok(transport)
603 }
604
605 pub async fn server_handshake(&mut self) -> io::Result<()> {
609 self.handshake(Handshake::server()).await
610 }
611
612 pub async fn handshake(&mut self, handshake: Handshake) -> io::Result<()> {
648 let old_codec = std::mem::replace(&mut self.codec, Box::new(PlainCodec::new()));
653 self.clear();
654
655 let backup = std::mem::take(&mut self.backup);
657
658 match self.handshake_impl(handshake).await {
662 Ok(codec) => {
663 self.set_codec(codec);
664 self.backup = backup;
665 Ok(())
666 }
667 Err(x) => {
668 self.set_codec(old_codec);
669 self.clear();
670 self.backup = backup;
671 Err(x)
672 }
673 }
674 }
675
676 async fn handshake_impl(&mut self, handshake: Handshake) -> io::Result<BoxedCodec> {
677 #[derive(Debug, Serialize, Deserialize)]
678 struct Choice {
679 compression_level: Option<CompressionLevel>,
680 compression_type: Option<CompressionType>,
681 encryption_type: Option<EncryptionType>,
682 }
683
684 #[derive(Debug, Serialize, Deserialize)]
685 struct Options {
686 compression_types: Vec<CompressionType>,
687 encryption_types: Vec<EncryptionType>,
688 }
689
690 let log_label = if handshake.is_client() {
692 "Handshake | Client"
693 } else {
694 "Handshake | Server"
695 };
696
697 let choice = match handshake {
699 Handshake::Client {
700 preferred_compression_type,
701 preferred_compression_level,
702 preferred_encryption_type,
703 } => {
704 debug!("[{log_label}] Waiting on options");
706 let options = self.read_frame_as::<Options>().await?.ok_or_else(|| {
707 io::Error::new(
708 io::ErrorKind::UnexpectedEof,
709 "Transport closed early while waiting for options",
710 )
711 })?;
712
713 debug!("[{log_label}] Selecting from options: {options:?}");
715 let choice = Choice {
716 compression_type: preferred_compression_type
719 .filter(|ty| options.compression_types.contains(ty)),
720
721 compression_level: preferred_compression_level,
723
724 encryption_type: preferred_encryption_type
727 .filter(|ty| options.encryption_types.contains(ty))
728 .or_else(|| {
729 options
730 .encryption_types
731 .iter()
732 .find(|ty| !ty.is_unknown())
733 .copied()
734 }),
735 };
736
737 debug!("[{log_label}] Reporting choice: {choice:?}");
739 self.write_frame_for(&choice).await?;
740
741 choice
742 }
743 Handshake::Server {
744 compression_types,
745 encryption_types,
746 } => {
747 let options = Options {
748 compression_types: compression_types.to_vec(),
749 encryption_types: encryption_types.to_vec(),
750 };
751
752 debug!("[{log_label}] Sending options: {options:?}");
754 self.write_frame_for(&options).await?;
755
756 debug!("[{log_label}] Waiting on choice");
758 self.read_frame_as::<Choice>().await?.ok_or_else(|| {
759 io::Error::new(
760 io::ErrorKind::UnexpectedEof,
761 "Transport closed early while waiting for choice",
762 )
763 })?
764 }
765 };
766
767 debug!("[{log_label}] Building compression & encryption codecs based on {choice:?}");
768 let compression_level = choice.compression_level.unwrap_or_default();
769
770 let compression_codec = choice
772 .compression_type
773 .map(|ty| ty.new_codec(compression_level))
774 .transpose()?;
775
776 let encryption_codec = match choice.encryption_type {
779 Some(EncryptionType::Unknown) => {
781 return Err(io::Error::new(
782 io::ErrorKind::InvalidInput,
783 "Unknown compression type",
784 ))
785 }
786 Some(ty) => {
787 let key = self.exchange_keys_impl(log_label).await?;
788 Some(ty.new_codec(key.unprotected_as_bytes())?)
789 }
790 None => None,
791 };
792
793 trace!("[{log_label}] Bundling codecs");
795 let codec: BoxedCodec = match (compression_codec, encryption_codec) {
796 (Some(c), Some(e)) => Box::new(ChainCodec::new(e, c)),
799
800 (Some(c), None) => Box::new(c),
802
803 (None, Some(e)) => Box::new(e),
805
806 (None, None) => Box::new(PlainCodec::new()),
808 };
809
810 Ok(codec)
811 }
812
813 pub async fn exchange_keys(&mut self) -> io::Result<SecretKey32> {
816 self.exchange_keys_impl("").await
817 }
818
819 async fn exchange_keys_impl(&mut self, label: &str) -> io::Result<SecretKey32> {
820 let log_label = if label.is_empty() {
821 String::new()
822 } else {
823 format!("[{label}] ")
824 };
825
826 #[derive(Serialize, Deserialize)]
827 struct KeyExchangeData {
828 #[serde(with = "serde_bytes")]
830 public_key: PublicKeyBytes,
831
832 #[serde(with = "serde_bytes")]
834 salt: Salt,
835 }
836
837 debug!("{log_label}Exchanging public key and salt");
838 let exchange = KeyExchange::default();
839 self.write_frame_for(&KeyExchangeData {
840 public_key: exchange.pk_bytes(),
841 salt: *exchange.salt(),
842 })
843 .await?;
844
845 trace!("{log_label}Waiting on public key and salt from other side");
850 let data = self
851 .read_frame_as::<KeyExchangeData>()
852 .await?
853 .ok_or_else(|| {
854 io::Error::new(
855 io::ErrorKind::UnexpectedEof,
856 "Transport closed early while waiting for key data",
857 )
858 })?;
859
860 trace!("{log_label}Deriving shared secret key");
861 let key = exchange.derive_shared_secret(data.public_key, data.salt)?;
862 Ok(key)
863 }
864}
865
866#[async_trait]
867impl<T> Reconnectable for FramedTransport<T>
868where
869 T: Transport,
870{
871 async fn reconnect(&mut self) -> io::Result<()> {
872 Reconnectable::reconnect(&mut self.inner).await
873 }
874}
875
876impl FramedTransport<InmemoryTransport> {
877 pub fn pair(
882 buffer: usize,
883 ) -> (
884 FramedTransport<InmemoryTransport>,
885 FramedTransport<InmemoryTransport>,
886 ) {
887 let (a, b) = InmemoryTransport::pair(buffer);
888 let a = FramedTransport::new(a, Box::new(PlainCodec::new()));
889 let b = FramedTransport::new(b, Box::new(PlainCodec::new()));
890 (a, b)
891 }
892
893 pub fn link(&mut self, other: &mut Self, buffer: usize) {
895 self.inner.link(&mut other.inner, buffer)
896 }
897}
898
#[cfg(test)]
impl FramedTransport<InmemoryTransport> {
    /// Test-only alias for [`FramedTransport::pair`].
    pub fn test_pair(
        buffer: usize,
    ) -> (
        FramedTransport<InmemoryTransport>,
        FramedTransport<InmemoryTransport>,
    ) {
        FramedTransport::<InmemoryTransport>::pair(buffer)
    }
}
911
912#[cfg(test)]
913mod tests {
914 use bytes::BufMut;
915 use test_log::test;
916
917 use super::*;
918 use crate::common::TestTransport;
919
    /// Test codec that always succeeds and returns the frame unchanged
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct OkCodec;

    impl Codec for OkCodec {
        fn encode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Ok(frame)
        }

        fn decode<'a>(&mut self, frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Ok(frame)
        }
    }
933
    /// Test codec whose encode and decode always fail with `io::ErrorKind::Other`
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct ErrCodec;

    impl Codec for ErrCodec {
        fn encode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Err(io::Error::from(io::ErrorKind::Other))
        }

        fn decode<'a>(&mut self, _frame: Frame<'a>) -> io::Result<Frame<'a>> {
            Err(io::Error::from(io::ErrorKind::Other))
        }
    }
947
    /// Test codec that ignores its input and returns a fixed frame: `b"encode"` from encode
    /// and `b"decode"` from decode
    #[derive(Clone)]
    struct CustomCodec;

    impl Codec for CustomCodec {
        fn encode<'a>(&mut self, _: Frame<'a>) -> io::Result<Frame<'a>> {
            Ok(Frame::new(b"encode"))
        }

        fn decode<'a>(&mut self, _: Frame<'a>) -> io::Result<Frame<'a>> {
            Ok(Frame::new(b"decode"))
        }
    }
961
962 type SimulateTryReadFn = Box<dyn Fn(&mut [u8]) -> io::Result<usize> + Send + Sync>;
963
964 fn simulate_try_read(
971 frames: Vec<Frame>,
972 step: usize,
973 block_on: impl Fn(usize) -> bool + Send + Sync + 'static,
974 ) -> SimulateTryReadFn {
975 use std::sync::atomic::{AtomicUsize, Ordering};
976
977 let data = {
979 let mut buf = BytesMut::new();
980
981 for frame in frames {
982 frame.write(&mut buf);
983 }
984
985 buf.to_vec()
986 };
987
988 let idx = AtomicUsize::new(0);
989 let cnt = AtomicUsize::new(0);
990
991 Box::new(move |buf| {
992 if block_on(cnt.fetch_add(1, Ordering::Relaxed)) {
993 return Err(io::Error::from(io::ErrorKind::WouldBlock));
994 }
995
996 let start = idx.fetch_add(step, Ordering::Relaxed);
997 let end = start + step;
998 let end = if end > data.len() { data.len() } else { end };
999 let len = if start > end { 0 } else { end - start };
1000
1001 buf[..len].copy_from_slice(&data[start..end]);
1002 Ok(len)
1003 })
1004 }
1005
1006 #[test]
1007 fn try_read_frame_should_return_would_block_if_fails_to_read_frame_before_blocking() {
1008 let mut transport = FramedTransport::new(
1010 TestTransport {
1011 f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))),
1012 f_ready: Box::new(|_| Ok(Ready::READABLE)),
1013 ..Default::default()
1014 },
1015 Box::new(OkCodec),
1016 );
1017 assert_eq!(
1018 transport.try_read_frame().unwrap_err().kind(),
1019 io::ErrorKind::WouldBlock
1020 );
1021
1022 let mut transport = FramedTransport::new(
1024 TestTransport {
1025 f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |cnt| cnt == 1),
1026 f_ready: Box::new(|_| Ok(Ready::READABLE)),
1027 ..Default::default()
1028 },
1029 Box::new(OkCodec),
1030 );
1031 assert_eq!(
1032 transport.try_read_frame().unwrap_err().kind(),
1033 io::ErrorKind::WouldBlock
1034 );
1035 }
1036
1037 #[test]
1038 fn try_read_frame_should_return_error_if_encountered_error_with_reading_bytes() {
1039 let mut transport = FramedTransport::new(
1040 TestTransport {
1041 f_try_read: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1042 f_ready: Box::new(|_| Ok(Ready::READABLE)),
1043 ..Default::default()
1044 },
1045 Box::new(OkCodec),
1046 );
1047 assert_eq!(
1048 transport.try_read_frame().unwrap_err().kind(),
1049 io::ErrorKind::NotConnected
1050 );
1051 }
1052
1053 #[test]
1054 fn try_read_frame_should_return_error_if_encountered_error_during_decode() {
1055 let mut transport = FramedTransport::new(
1056 TestTransport {
1057 f_try_read: simulate_try_read(vec![Frame::new(b"some data")], 1, |_| false),
1058 f_ready: Box::new(|_| Ok(Ready::READABLE)),
1059 ..Default::default()
1060 },
1061 Box::new(ErrCodec),
1062 );
1063 assert_eq!(
1064 transport.try_read_frame().unwrap_err().kind(),
1065 io::ErrorKind::Other
1066 );
1067 }
1068
1069 #[test]
1070 fn try_read_frame_should_return_next_available_frame() {
1071 let data = {
1072 let mut data = BytesMut::new();
1073 Frame::new(b"hello world").write(&mut data);
1074 data.freeze()
1075 };
1076
1077 let mut transport = FramedTransport::new(
1078 TestTransport {
1079 f_try_read: Box::new(move |buf| {
1080 buf[..data.len()].copy_from_slice(data.as_ref());
1081 Ok(data.len())
1082 }),
1083 f_ready: Box::new(|_| Ok(Ready::READABLE)),
1084 ..Default::default()
1085 },
1086 Box::new(OkCodec),
1087 );
1088 assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
1089 }
1090
1091 #[test]
1092 fn try_read_frame_should_return_next_available_frame_if_already_in_incoming_buffer() {
1093 let data = {
1095 let mut data = BytesMut::new();
1096 Frame::new(b"hello world").write(&mut data);
1097 Frame::new(b"hello again").write(&mut data);
1098 data.freeze()
1099 };
1100
1101 let mut transport = FramedTransport::new(
1105 TestTransport {
1106 f_try_read: Box::new(move |buf| {
1107 static mut CNT: usize = 0;
1108 unsafe {
1109 CNT += 1;
1110 if CNT == 2 {
1111 Err(io::Error::from(io::ErrorKind::WouldBlock))
1112 } else {
1113 let n = data.len();
1114 buf[..data.len()].copy_from_slice(data.as_ref());
1115 Ok(n)
1116 }
1117 }
1118 }),
1119 f_ready: Box::new(|_| Ok(Ready::READABLE)),
1120 ..Default::default()
1121 },
1122 Box::new(OkCodec),
1123 );
1124
1125 assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
1127
1128 assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello again");
1130 }
1131
1132 #[test]
1133 fn try_read_frame_should_keep_reading_until_a_frame_is_found() {
1134 const STEP_SIZE: usize = Frame::HEADER_SIZE + 7;
1135
1136 let mut transport = FramedTransport::new(
1137 TestTransport {
1138 f_try_read: simulate_try_read(
1139 vec![Frame::new(b"hello world"), Frame::new(b"test hello")],
1140 STEP_SIZE,
1141 |_| false,
1142 ),
1143 f_ready: Box::new(|_| Ok(Ready::READABLE)),
1144 ..Default::default()
1145 },
1146 Box::new(OkCodec),
1147 );
1148 assert_eq!(transport.try_read_frame().unwrap().unwrap(), b"hello world");
1149
1150 assert_eq!(
1153 transport.incoming.to_vec(),
1154 [0, 0, 0, 0, 0, 0, 0, 10, b't', b'e', b's']
1155 );
1156 }
1157
1158 #[test]
1159 fn try_write_frame_should_return_would_block_if_fails_to_write_frame_before_blocking() {
1160 let mut transport = FramedTransport::new(
1161 TestTransport {
1162 f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::WouldBlock))),
1163 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1164 ..Default::default()
1165 },
1166 Box::new(OkCodec),
1167 );
1168
1169 assert_eq!(
1171 transport
1172 .try_write_frame(b"hello world")
1173 .unwrap_err()
1174 .kind(),
1175 io::ErrorKind::WouldBlock
1176 );
1177 }
1178
1179 #[test]
1180 fn try_write_frame_should_return_error_if_encountered_error_with_writing_bytes() {
1181 let mut transport = FramedTransport::new(
1182 TestTransport {
1183 f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1184 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1185 ..Default::default()
1186 },
1187 Box::new(OkCodec),
1188 );
1189 assert_eq!(
1190 transport
1191 .try_write_frame(b"hello world")
1192 .unwrap_err()
1193 .kind(),
1194 io::ErrorKind::NotConnected
1195 );
1196 }
1197
1198 #[test]
1199 fn try_write_frame_should_return_error_if_encountered_error_during_encode() {
1200 let mut transport = FramedTransport::new(
1201 TestTransport {
1202 f_try_write: Box::new(|buf| Ok(buf.len())),
1203 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1204 ..Default::default()
1205 },
1206 Box::new(ErrCodec),
1207 );
1208 assert_eq!(
1209 transport
1210 .try_write_frame(b"hello world")
1211 .unwrap_err()
1212 .kind(),
1213 io::ErrorKind::Other
1214 );
1215 }
1216
1217 #[test]
1218 fn try_write_frame_should_write_entire_frame_if_possible() {
1219 let (tx, rx) = std::sync::mpsc::sync_channel(1);
1220 let mut transport = FramedTransport::new(
1221 TestTransport {
1222 f_try_write: Box::new(move |buf| {
1223 let len = buf.len();
1224 tx.send(buf.to_vec()).unwrap();
1225 Ok(len)
1226 }),
1227 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1228 ..Default::default()
1229 },
1230 Box::new(OkCodec),
1231 );
1232
1233 transport.try_write_frame(b"hello world").unwrap();
1234
1235 assert_eq!(
1237 rx.try_recv().unwrap(),
1238 [11u64.to_be_bytes().as_slice(), b"hello world".as_slice()].concat()
1239 );
1240 }
1241
1242 #[test]
1243 fn try_write_frame_should_write_any_prior_queued_bytes_before_writing_next_frame() {
1244 const STEP_SIZE: usize = Frame::HEADER_SIZE + 5;
1245 let (tx, rx) = std::sync::mpsc::sync_channel(10);
1246 let mut transport = FramedTransport::new(
1247 TestTransport {
1248 f_try_write: Box::new(move |buf| {
1249 static mut CNT: usize = 0;
1250 unsafe {
1251 CNT += 1;
1252 if CNT == 2 {
1253 Err(io::Error::from(io::ErrorKind::WouldBlock))
1254 } else {
1255 let len = std::cmp::min(STEP_SIZE, buf.len());
1256 tx.send(buf[..len].to_vec()).unwrap();
1257 Ok(len)
1258 }
1259 }
1260 }),
1261 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1262 ..Default::default()
1263 },
1264 Box::new(OkCodec),
1265 );
1266
1267 assert_eq!(
1269 transport
1270 .try_write_frame(b"hello world")
1271 .unwrap_err()
1272 .kind(),
1273 io::ErrorKind::WouldBlock
1274 );
1275
1276 assert_eq!(
1278 rx.try_recv().unwrap(),
1279 [11u64.to_be_bytes().as_slice(), b"hello".as_slice()].concat()
1280 );
1281 assert_eq!(
1282 rx.try_recv().unwrap_err(),
1283 std::sync::mpsc::TryRecvError::Empty
1284 );
1285
1286 transport.try_write_frame(b"test").unwrap();
1288 assert_eq!(
1289 rx.try_recv().unwrap(),
1290 [b' ', b'w', b'o', b'r', b'l', b'd', 0, 0, 0, 0, 0, 0, 0]
1291 );
1292 assert_eq!(rx.try_recv().unwrap(), [4, b't', b'e', b's', b't']);
1293 assert_eq!(
1294 rx.try_recv().unwrap_err(),
1295 std::sync::mpsc::TryRecvError::Empty
1296 );
1297 }
1298
1299 #[test]
1300 fn try_flush_should_return_error_if_try_write_fails() {
1301 let mut transport = FramedTransport::new(
1302 TestTransport {
1303 f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1304 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1305 ..Default::default()
1306 },
1307 Box::new(OkCodec),
1308 );
1309
1310 transport.outgoing.put_slice(b"hello world");
1312
1313 assert_eq!(
1315 transport.try_flush().unwrap_err().kind(),
1316 io::ErrorKind::NotConnected
1317 );
1318 }
1319
1320 #[test]
1321 fn try_flush_should_return_error_if_try_write_returns_0_bytes_written() {
1322 let mut transport = FramedTransport::new(
1323 TestTransport {
1324 f_try_write: Box::new(|_| Ok(0)),
1325 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1326 ..Default::default()
1327 },
1328 Box::new(OkCodec),
1329 );
1330
1331 transport.outgoing.put_slice(b"hello world");
1333
1334 assert_eq!(
1336 transport.try_flush().unwrap_err().kind(),
1337 io::ErrorKind::WriteZero
1338 );
1339 }
1340
1341 #[test]
1342 fn try_flush_should_be_noop_if_nothing_to_flush() {
1343 let mut transport = FramedTransport::new(
1344 TestTransport {
1345 f_try_write: Box::new(|_| Err(io::Error::from(io::ErrorKind::NotConnected))),
1346 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1347 ..Default::default()
1348 },
1349 Box::new(OkCodec),
1350 );
1351
1352 transport.try_flush().unwrap();
1354 }
1355
1356 #[test]
1357 fn try_flush_should_continually_call_try_write_until_outgoing_buffer_is_empty() {
1358 const STEP_SIZE: usize = 5;
1359 let (tx, rx) = std::sync::mpsc::sync_channel(10);
1360 let mut transport = FramedTransport::new(
1361 TestTransport {
1362 f_try_write: Box::new(move |buf| {
1363 let len = std::cmp::min(STEP_SIZE, buf.len());
1364 tx.send(buf[..len].to_vec()).unwrap();
1365 Ok(len)
1366 }),
1367 f_ready: Box::new(|_| Ok(Ready::WRITABLE)),
1368 ..Default::default()
1369 },
1370 Box::new(OkCodec),
1371 );
1372
1373 transport.outgoing.put_slice(b"hello world");
1375
1376 transport.try_flush().unwrap();
1378
1379 assert_eq!(rx.try_recv().unwrap(), b"hello".as_slice());
1381 assert_eq!(rx.try_recv().unwrap(), b" worl".as_slice());
1382 assert_eq!(rx.try_recv().unwrap(), b"d".as_slice());
1383 assert_eq!(
1384 rx.try_recv().unwrap_err(),
1385 std::sync::mpsc::TryRecvError::Empty
1386 );
1387 }
1388
    /// Plays one side of the statistics exchange that `synchronize` performs: writes
    /// `(sent_cnt, received_cnt, available_cnt)` as this side's stats, then reads the other
    /// side's stats and asserts that they equal the three `expected_*` values.
    #[inline]
    async fn test_synchronize_stats(
        transport: &mut FramedTransport<InmemoryTransport>,
        sent_cnt: u64,
        received_cnt: u64,
        available_cnt: u64,
        expected_sent_cnt: u64,
        expected_received_cnt: u64,
        expected_available_cnt: u64,
    ) {
        // Report our stats to the other side
        transport
            .write_frame_for(&(sent_cnt, received_cnt, available_cnt))
            .await
            .unwrap();

        // Receive the other side's stats and check each counter
        let (sent, received, available) = transport
            .read_frame_as::<(u64, u64, u64)>()
            .await
            .unwrap()
            .unwrap();
        assert_eq!(sent, expected_sent_cnt, "Wrong sent cnt");
        assert_eq!(received, expected_received_cnt, "Wrong received cnt");
        assert_eq!(available, expected_available_cnt, "Wrong available cnt");
    }
1416
    #[test(tokio::test)]
    async fn synchronize_should_resend_no_frames_if_other_side_claims_it_has_more_than_us() {
        let (mut t1, mut t2) = FramedTransport::pair(100);

        // t2 has sent one frame and still holds it in its backup
        t2.backup.push_frame(Frame::new(b"hello world"));
        t2.backup.increment_sent_cnt();

        // Synchronize t2 in the background, then send a marker frame
        let _task = tokio::spawn(async move {
            t2.synchronize().await.unwrap();
            t2.write_frame(Frame::new(b"done")).await.unwrap();
            t2
        });

        // t1 claims to have received 2 frames, more than the 1 frame t2 sent
        test_synchronize_stats(&mut t1, 0, 2, 0, 1, 0, 1).await;

        // The marker is the first frame to arrive, so nothing was resent
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
    }
1441
    #[test(tokio::test)]
    async fn synchronize_should_resend_no_frames_if_none_missing_on_other_side() {
        let (mut t1, mut t2) = FramedTransport::pair(100);

        // t2 has sent one frame and still holds it in its backup
        t2.backup.push_frame(Frame::new(b"hello world"));
        t2.backup.increment_sent_cnt();

        // Synchronize t2 in the background, then send a marker frame
        let _task = tokio::spawn(async move {
            t2.synchronize().await.unwrap();
            t2.write_frame(Frame::new(b"done")).await.unwrap();
            t2
        });

        // t1 claims to have received 1 frame, matching the 1 frame t2 sent
        test_synchronize_stats(&mut t1, 0, 1, 0, 1, 0, 1).await;

        // The marker is the first frame to arrive, so nothing was resent
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
    }
1466
    #[test(tokio::test)]
    async fn synchronize_should_resend_some_frames_if_some_missing_on_other_side() {
        let (mut t1, mut t2) = FramedTransport::pair(100);

        // t2 has sent two frames and still holds both in its backup
        t2.backup.push_frame(Frame::new(b"hello"));
        t2.backup.push_frame(Frame::new(b"world"));
        t2.backup.increment_sent_cnt();
        t2.backup.increment_sent_cnt();

        // Synchronize t2 in the background, then send a marker frame
        let _task = tokio::spawn(async move {
            t2.synchronize().await.unwrap();
            t2.write_frame(Frame::new(b"done")).await.unwrap();
            t2
        });

        // t1 claims to have received 1 of the 2 frames t2 sent
        test_synchronize_stats(&mut t1, 0, 1, 0, 2, 0, 2).await;

        // Only the most recent frame is resent, followed by the marker
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
    }
1494
    #[test(tokio::test)]
    async fn synchronize_should_resend_all_frames_if_all_missing_on_other_side() {
        let (mut t1, mut t2) = FramedTransport::pair(100);

        // t2 has sent two frames and still holds both in its backup
        t2.backup.push_frame(Frame::new(b"hello"));
        t2.backup.push_frame(Frame::new(b"world"));
        t2.backup.increment_sent_cnt();
        t2.backup.increment_sent_cnt();

        // Synchronize t2 in the background, then send a marker frame
        let _task = tokio::spawn(async move {
            t2.synchronize().await.unwrap();
            t2.write_frame(Frame::new(b"done")).await.unwrap();
            t2
        });

        // t1 claims to have received none of the 2 frames t2 sent
        test_synchronize_stats(&mut t1, 0, 0, 0, 2, 0, 2).await;

        // Both frames are resent in order, followed by the marker
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
    }
1523
    #[test(tokio::test)]
    async fn synchronize_should_resend_available_frames_if_more_than_available_missing_on_other_side(
    ) {
        let (mut t1, mut t2) = FramedTransport::pair(100);

        // t2 has sent three frames but only two of them remain in its backup
        t2.backup.push_frame(Frame::new(b"hello"));
        t2.backup.push_frame(Frame::new(b"world"));
        t2.backup.increment_sent_cnt();
        t2.backup.increment_sent_cnt();
        t2.backup.increment_sent_cnt();

        // Synchronize t2 in the background, then send a marker frame
        let _task = tokio::spawn(async move {
            t2.synchronize().await.unwrap();
            t2.write_frame(Frame::new(b"done")).await.unwrap();
            t2
        });

        // t1 claims to have received none of the 3 frames; t2 reports 2 available
        test_synchronize_stats(&mut t1, 0, 0, 0, 3, 0, 2).await;

        // Only the two frames still available are resent, followed by the marker
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
    }
1556
1557 #[test(tokio::test)]
1558 async fn synchronize_should_receive_no_frames_if_other_side_claims_it_has_more_than_us() {
1559 let (mut t1, mut t2) = FramedTransport::pair(100);
1560
1561 t2.backup.increment_received_cnt();
1563
1564 let _task = tokio::spawn(async move {
1568 t2.synchronize().await.unwrap();
1569 t2.write_frame(Frame::new(b"done")).await.unwrap();
1570 t2
1571 });
1572
1573 test_synchronize_stats(&mut t1, 0, 0, 0, 0, 1, 0).await;
1576
1577 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1579 }
1580
1581 #[test(tokio::test)]
1582 async fn synchronize_should_receive_no_frames_if_none_missing_from_other_side() {
1583 let (mut t1, mut t2) = FramedTransport::pair(100);
1584
1585 t2.backup.increment_received_cnt();
1587
1588 let _task = tokio::spawn(async move {
1592 t2.synchronize().await.unwrap();
1593 t2.write_frame(Frame::new(b"done")).await.unwrap();
1594 t2
1595 });
1596
1597 test_synchronize_stats(&mut t1, 1, 0, 1, 0, 1, 0).await;
1600
1601 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1603 }
1604
1605 #[test(tokio::test)]
1606 async fn synchronize_should_receive_some_frames_if_some_missing_from_other_side() {
1607 let (mut t1, mut t2) = FramedTransport::pair(100);
1608
1609 t2.backup.increment_received_cnt();
1611
1612 let task = tokio::spawn(async move {
1616 t2.synchronize().await.unwrap();
1617 t2.write_frame(Frame::new(b"done")).await.unwrap();
1618 t2
1619 });
1620
1621 test_synchronize_stats(&mut t1, 2, 0, 2, 0, 1, 0).await;
1624
1625 t1.write_frame(Frame::new(b"hello")).await.unwrap();
1627
1628 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1630
1631 drop(t1);
1633
1634 let mut t2 = task.await.unwrap();
1636 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1637 assert_eq!(t2.read_frame().await.unwrap(), None);
1638 }
1639
1640 #[test(tokio::test)]
1641 async fn synchronize_should_receive_all_frames_if_all_missing_from_other_side() {
1642 let (mut t1, mut t2) = FramedTransport::pair(100);
1643
1644 let task = tokio::spawn(async move {
1648 t2.synchronize().await.unwrap();
1649 t2.write_frame(Frame::new(b"done")).await.unwrap();
1650 t2
1651 });
1652
1653 test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await;
1656
1657 t1.write_frame(Frame::new(b"hello")).await.unwrap();
1659 t1.write_frame(Frame::new(b"world")).await.unwrap();
1660
1661 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1663
1664 drop(t1);
1666
1667 let mut t2 = task.await.unwrap();
1669 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1670 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world");
1671 assert_eq!(t2.read_frame().await.unwrap(), None);
1672 }
1673
1674 #[test(tokio::test)]
1675 async fn synchronize_should_receive_all_frames_if_more_than_all_missing_from_other_side() {
1676 let (mut t1, mut t2) = FramedTransport::pair(100);
1677
1678 let task = tokio::spawn(async move {
1682 t2.synchronize().await.unwrap();
1683 t2.write_frame(Frame::new(b"done")).await.unwrap();
1684 t2
1685 });
1686
1687 test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await;
1690
1691 t1.write_frame(Frame::new(b"hello")).await.unwrap();
1693 t1.write_frame(Frame::new(b"world")).await.unwrap();
1694
1695 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1697
1698 drop(t1);
1700
1701 let mut t2 = task.await.unwrap();
1703 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1704 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world");
1705 assert_eq!(t2.read_frame().await.unwrap(), None);
1706 }
1707
1708 #[test(tokio::test)]
1709 async fn synchronize_should_fail_if_connection_terminated_before_receiving_missing_frames() {
1710 let (mut t1, mut t2) = FramedTransport::pair(100);
1711
1712 let task = tokio::spawn(async move {
1716 t2.synchronize().await.unwrap();
1717 t2.write_frame(Frame::new(b"done")).await.unwrap();
1718 t2
1719 });
1720
1721 test_synchronize_stats(&mut t1, 2, 0, 2, 0, 0, 0).await;
1724
1725 t1.write_frame(Frame::new(b"hello")).await.unwrap();
1727
1728 drop(t1);
1730
1731 task.await.unwrap_err();
1733 }
1734
1735 #[test(tokio::test)]
1736 async fn synchronize_should_fail_if_connection_terminated_while_waiting_for_frame_stats() {
1737 let (t1, mut t2) = FramedTransport::pair(100);
1738
1739 let task = tokio::spawn(async move {
1743 t2.synchronize().await.unwrap();
1744 t2.write_frame(Frame::new(b"done")).await.unwrap();
1745 t2
1746 });
1747
1748 drop(t1);
1750
1751 task.await.unwrap_err();
1753 }
1754
1755 #[test(tokio::test)]
1756 async fn synchronize_should_clear_any_prexisting_incoming_and_outgoing_data() {
1757 let (mut t1, mut t2) = FramedTransport::pair(100);
1758
1759 Frame::new(b"bad incoming").write(&mut t2.incoming);
1761 Frame::new(b"bad outgoing").write(&mut t2.outgoing);
1762
1763 t2.backup.push_frame(Frame::new(b"hello"));
1765 t2.backup.push_frame(Frame::new(b"world"));
1766 t2.backup.increment_sent_cnt();
1767 t2.backup.increment_sent_cnt();
1768
1769 let task = tokio::spawn(async move {
1773 t2.synchronize().await.unwrap();
1774 t2.write_frame(Frame::new(b"done")).await.unwrap();
1775 t2
1776 });
1777
1778 test_synchronize_stats(&mut t1, 2, 0, 2, 2, 0, 2).await;
1781
1782 t1.write_frame(Frame::new(b"one")).await.unwrap();
1784 t1.write_frame(Frame::new(b"two")).await.unwrap();
1785
1786 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
1788 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
1789 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1790
1791 drop(t1);
1793
1794 let mut t2 = task.await.unwrap();
1796 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"one");
1797 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"two");
1798 assert_eq!(t2.read_frame().await.unwrap(), None);
1799 }
1800
1801 #[test(tokio::test)]
1802 async fn synchronize_should_not_increment_the_sent_frames_or_store_replayed_frames_in_the_backup(
1803 ) {
1804 let (mut t1, mut t2) = FramedTransport::pair(100);
1805
1806 t2.backup.push_frame(Frame::new(b"hello"));
1808 t2.backup.push_frame(Frame::new(b"world"));
1809 t2.backup.increment_sent_cnt();
1810 t2.backup.increment_sent_cnt();
1811
1812 let task = tokio::spawn(async move {
1816 t2.synchronize().await.unwrap();
1817
1818 t2.backup.freeze();
1819 t2.write_frame(Frame::new(b"done")).await.unwrap();
1820 t2.backup.unfreeze();
1821
1822 t2
1823 });
1824
1825 test_synchronize_stats(&mut t1, 0, 0, 0, 2, 0, 2).await;
1828
1829 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
1831 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");
1832 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1833
1834 drop(t1);
1836
1837 let t2 = task.await.unwrap();
1839 assert_eq!(t2.backup.sent_cnt(), 2, "Wrong sent cnt");
1840 assert_eq!(t2.backup.received_cnt(), 0, "Wrong received cnt");
1841 assert_eq!(t2.backup.frame_cnt(), 2, "Wrong frame cnt");
1842 }
1843
1844 #[test(tokio::test)]
1845 async fn synchronize_should_update_the_backup_received_cnt_to_match_other_side_sent() {
1846 let (mut t1, mut t2) = FramedTransport::pair(100);
1847
1848 let task = tokio::spawn(async move {
1852 t2.synchronize().await.unwrap();
1853
1854 t2.backup.freeze();
1855 t2.write_frame(Frame::new(b"done")).await.unwrap();
1856 t2.backup.unfreeze();
1857
1858 t2
1859 });
1860
1861 test_synchronize_stats(&mut t1, 2, 0, 1, 0, 0, 0).await;
1864
1865 t1.write_frame(Frame::new(b"hello")).await.unwrap();
1867
1868 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"done");
1870
1871 drop(t1);
1873
1874 let t2 = task.await.unwrap();
1876 assert_eq!(t2.backup.sent_cnt(), 0, "Wrong sent cnt");
1877 assert_eq!(t2.backup.received_cnt(), 2, "Wrong received cnt");
1878 assert_eq!(t2.backup.frame_cnt(), 0, "Wrong frame cnt");
1879 }
1880
1881 #[test(tokio::test)]
1882 async fn synchronize_should_work_even_if_codec_changes_between_attempts() {
1883 let (mut t1, _t1_other) = FramedTransport::pair(100);
1884 let (mut t2, _t2_other) = FramedTransport::pair(100);
1885
1886 t1.write_frame(Frame::new(b"hello")).await.unwrap();
1888 t1.write_frame(Frame::new(b"world")).await.unwrap();
1889 t2.write_frame(Frame::new(b"foo")).await.unwrap();
1890 t2.write_frame(Frame::new(b"bar")).await.unwrap();
1891
1892 drop(_t1_other);
1894 drop(_t2_other);
1895 t1.link(&mut t2, 100);
1896 let codec = EncryptionCodec::new_xchacha20poly1305(Default::default());
1897 t1.codec = Box::new(codec.clone());
1898 t2.codec = Box::new(codec);
1899
1900 let task = tokio::spawn(async move {
1902 t2.synchronize().await.unwrap();
1903 t2
1904 });
1905
1906 t1.synchronize().await.unwrap();
1907
1908 let mut t2 = task.await.unwrap();
1910 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"foo");
1911 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"bar");
1912 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"hello");
1913 assert_eq!(t2.read_frame().await.unwrap().unwrap(), b"world");
1914 }
1915
1916 #[test(tokio::test)]
1917 async fn handshake_should_configure_transports_with_matching_codec() {
1918 let (mut t1, mut t2) = FramedTransport::test_pair(100);
1919
1920 let task = tokio::spawn(async move {
1923 t2.server_handshake().await.unwrap();
1925
1926 let frame = t2.read_frame().await.unwrap().unwrap();
1928 t2.write_frame(frame).await.unwrap();
1929 });
1930
1931 t1.client_handshake().await.unwrap();
1932
1933 t1.write_frame(b"hello world").await.unwrap();
1935 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello world");
1936
1937 task.await.unwrap();
1939 }
1940
1941 #[test(tokio::test)]
1942 async fn handshake_failing_should_ensure_existing_codec_remains() {
1943 let (mut t1, t2) = FramedTransport::test_pair(100);
1944
1945 t1.set_codec(Box::new(CustomCodec));
1947
1948 drop(t2);
1950
1951 t1.client_handshake().await.unwrap_err();
1953
1954 assert_eq!(t1.codec.encode(Frame::new(b"test")).unwrap(), b"encode");
1956 assert_eq!(t1.codec.decode(Frame::new(b"test")).unwrap(), b"decode");
1957 }
1958
1959 #[test(tokio::test)]
1960 async fn handshake_should_clear_any_intermittent_buffer_contents_prior_to_handshake_failing() {
1961 let (mut t1, t2) = FramedTransport::test_pair(100);
1962
1963 t1.set_codec(Box::new(CustomCodec));
1965
1966 drop(t2);
1968
1969 t1.incoming.extend_from_slice(b"garbage in");
1971 t1.outgoing.extend_from_slice(b"garbage out");
1972
1973 t1.client_handshake().await.unwrap_err();
1975
1976 assert!(t1.incoming.is_empty());
1978 assert!(t1.outgoing.is_empty());
1979 }
1980
1981 #[test(tokio::test)]
1982 async fn handshake_should_clear_any_intermittent_buffer_contents_prior_to_handshake_succeeding()
1983 {
1984 let (mut t1, mut t2) = FramedTransport::test_pair(100);
1985
1986 let task = tokio::spawn(async move {
1989 t2.server_handshake().await.unwrap();
1991
1992 let frame = t2.read_frame().await.unwrap().unwrap();
1994 t2.write_frame(frame).await.unwrap();
1995 });
1996
1997 t1.incoming.extend_from_slice(b"garbage in");
1999 t1.outgoing.extend_from_slice(b"garbage out");
2000
2001 t1.client_handshake().await.unwrap();
2002
2003 t1.write_frame(b"hello world").await.unwrap();
2005 assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello world");
2006
2007 task.await.unwrap();
2009
2010 assert!(t1.incoming.is_empty());
2012 assert!(t1.outgoing.is_empty());
2013 }
2014
2015 #[test(tokio::test)]
2016 async fn handshake_for_client_should_fail_if_receives_unexpected_frame_instead_of_options() {
2017 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2018
2019 let task = tokio::spawn(async move {
2022 t2.write_frame(b"not a valid frame for handshake")
2023 .await
2024 .unwrap();
2025 });
2026
2027 let err = t1.client_handshake().await.unwrap_err();
2029 assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2030
2031 task.await.unwrap();
2033 }
2034
2035 #[test(tokio::test)]
2036 async fn handshake_for_client_should_fail_unable_to_send_codec_choice_to_other_side() {
2037 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2038
2039 #[derive(Debug, Serialize, Deserialize)]
2040 struct Options {
2041 compression_types: Vec<CompressionType>,
2042 encryption_types: Vec<EncryptionType>,
2043 }
2044
2045 let task = tokio::spawn(async move {
2048 t2.write_frame_for(&Options {
2050 compression_types: Vec::new(),
2051 encryption_types: Vec::new(),
2052 })
2053 .await
2054 .unwrap();
2055 });
2056
2057 let err = t1.client_handshake().await.unwrap_err();
2059 assert_eq!(err.kind(), io::ErrorKind::WriteZero);
2060
2061 task.await.unwrap();
2063 }
2064
2065 #[test(tokio::test)]
2066 async fn handshake_for_client_should_fail_if_unable_to_receive_key_exchange_data_from_other_side(
2067 ) {
2068 #[derive(Debug, Serialize, Deserialize)]
2069 struct Options {
2070 compression_types: Vec<CompressionType>,
2071 encryption_types: Vec<EncryptionType>,
2072 }
2073
2074 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2075
2076 t2.write_frame_for(&Options {
2078 compression_types: CompressionType::known_variants().to_vec(),
2079 encryption_types: EncryptionType::known_variants().to_vec(),
2080 })
2081 .await
2082 .unwrap();
2083
2084 t2.write_frame(b"not valid key exchange data")
2085 .await
2086 .unwrap();
2087
2088 let err = t1.client_handshake().await.unwrap_err();
2090 assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2091 }
2092
2093 #[test(tokio::test)]
2094 async fn handshake_for_server_should_fail_if_receives_unexpected_frame_instead_of_choice() {
2095 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2096
2097 let task = tokio::spawn(async move {
2100 t2.write_frame(b"not a valid frame for handshake")
2101 .await
2102 .unwrap();
2103 });
2104
2105 let err = t1.server_handshake().await.unwrap_err();
2107 assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2108
2109 task.await.unwrap();
2111 }
2112
2113 #[test(tokio::test)]
2114 async fn handshake_for_server_should_fail_unable_to_send_codec_options_to_other_side() {
2115 let (mut t1, t2) = FramedTransport::test_pair(100);
2116
2117 drop(t2);
2119
2120 let err = t1.server_handshake().await.unwrap_err();
2122 assert_eq!(err.kind(), io::ErrorKind::WriteZero);
2123 }
2124
2125 #[test(tokio::test)]
2126 async fn handshake_for_server_should_fail_if_selected_codec_choice_uses_an_unknown_compression_type(
2127 ) {
2128 #[derive(Debug, Serialize, Deserialize)]
2129 struct Choice {
2130 compression_level: Option<CompressionLevel>,
2131 compression_type: Option<CompressionType>,
2132 encryption_type: Option<EncryptionType>,
2133 }
2134
2135 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2136
2137 t2.write_frame_for(&Choice {
2139 compression_level: None,
2140 compression_type: Some(CompressionType::Unknown),
2141 encryption_type: None,
2142 })
2143 .await
2144 .unwrap();
2145
2146 let err = t1.server_handshake().await.unwrap_err();
2148 assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
2149 }
2150
2151 #[test(tokio::test)]
2152 async fn handshake_for_server_should_fail_if_selected_codec_choice_uses_an_unknown_encryption_type(
2153 ) {
2154 #[derive(Debug, Serialize, Deserialize)]
2155 struct Choice {
2156 compression_level: Option<CompressionLevel>,
2157 compression_type: Option<CompressionType>,
2158 encryption_type: Option<EncryptionType>,
2159 }
2160
2161 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2162
2163 t2.write_frame_for(&Choice {
2165 compression_level: None,
2166 compression_type: None,
2167 encryption_type: Some(EncryptionType::Unknown),
2168 })
2169 .await
2170 .unwrap();
2171
2172 let err = t1.server_handshake().await.unwrap_err();
2174 assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
2175 }
2176
2177 #[test(tokio::test)]
2178 async fn handshake_for_server_should_fail_if_unable_to_receive_key_exchange_data_from_other_side(
2179 ) {
2180 #[derive(Debug, Serialize, Deserialize)]
2181 struct Choice {
2182 compression_level: Option<CompressionLevel>,
2183 compression_type: Option<CompressionType>,
2184 encryption_type: Option<EncryptionType>,
2185 }
2186
2187 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2188
2189 t2.write_frame_for(&Choice {
2191 compression_level: None,
2192 compression_type: None,
2193 encryption_type: Some(EncryptionType::XChaCha20Poly1305),
2194 })
2195 .await
2196 .unwrap();
2197
2198 t2.write_frame(b"not valid key exchange data")
2199 .await
2200 .unwrap();
2201
2202 let err = t1.server_handshake().await.unwrap_err();
2204 assert_eq!(err.kind(), io::ErrorKind::InvalidData);
2205 }
2206
2207 #[test(tokio::test)]
2208 async fn exchange_keys_should_fail_if_unable_to_send_exchange_data_to_other_side() {
2209 let (mut t1, t2) = FramedTransport::test_pair(100);
2210
2211 drop(t2);
2213
2214 assert_eq!(
2216 t1.exchange_keys().await.unwrap_err().kind(),
2217 io::ErrorKind::WriteZero
2218 );
2219 }
2220
2221 #[test(tokio::test)]
2222 async fn exchange_keys_should_fail_if_received_invalid_exchange_data() {
2223 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2224
2225 t2.write_frame(b"some invalid frame").await.unwrap();
2227
2228 assert_eq!(
2230 t1.exchange_keys().await.unwrap_err().kind(),
2231 io::ErrorKind::InvalidData
2232 );
2233 }
2234
2235 #[test(tokio::test)]
2236 async fn exchange_keys_should_return_shared_secret_key_if_successful() {
2237 let (mut t1, mut t2) = FramedTransport::test_pair(100);
2238
2239 let task = tokio::spawn(async move { t2.exchange_keys().await.unwrap() });
2241
2242 let key = t1.exchange_keys().await.unwrap();
2244
2245 assert_eq!(key, task.await.unwrap());
2247 }
2248}