1use std::io;
2use std::ops::{Deref, DerefMut};
3
4use async_trait::async_trait;
5use distant_auth::{AuthHandler, Authenticate, Verifier};
6use log::*;
7use serde::{Deserialize, Serialize};
8use tokio::sync::oneshot;
9
10#[cfg(test)]
11use crate::common::InmemoryTransport;
12use crate::common::{
13 Backup, FramedTransport, HeapSecretKey, Keychain, KeychainResult, Reconnectable, Transport,
14 TransportExt, Version,
15};
16
/// Numeric identifier of a [`Connection`]. Both sides start with a value from
/// `rand::random`; on a new connection the client then adopts the id the server
/// sends back, and on reconnect the existing id is reused.
pub type ConnectionId = u32;
19
/// A framed transport tagged with the side (client or server) it represents,
/// together with the state that side needs in order to survive a reconnect.
#[derive(Debug)]
pub enum Connection<T> {
    /// Client side of a connection.
    Client {
        /// Id of this connection; set from the id the server sends during
        /// `Connection::client`.
        id: ConnectionId,

        /// One-time password presented to the server on reconnect; replaced with
        /// the key derived from `exchange_keys` after each successful (re)connect.
        reauth_otp: HeapSecretKey,

        /// Underlying framed transport used to send and receive frames.
        transport: FramedTransport<T>,
    },

    /// Server side of a connection.
    Server {
        /// Id of this connection; also the key of its keychain entry.
        id: ConnectionId,

        /// Sender used on drop to hand the transport's backup to whoever holds
        /// the paired receiver (the server stores that receiver in the keychain).
        tx: oneshot::Sender<Backup>,

        /// Underlying framed transport used to send and receive frames.
        transport: FramedTransport<T>,
    },
}
47
48impl<T> Deref for Connection<T> {
49 type Target = FramedTransport<T>;
50
51 fn deref(&self) -> &Self::Target {
52 match self {
53 Self::Client { transport, .. } => transport,
54 Self::Server { transport, .. } => transport,
55 }
56 }
57}
58
59impl<T> DerefMut for Connection<T> {
60 fn deref_mut(&mut self) -> &mut Self::Target {
61 match self {
62 Self::Client { transport, .. } => transport,
63 Self::Server { transport, .. } => transport,
64 }
65 }
66}
67
68impl<T> Drop for Connection<T> {
69 fn drop(&mut self) {
72 match self {
73 Self::Client { .. } => (),
74 Self::Server { tx, transport, .. } => {
75 let backup = std::mem::take(&mut transport.backup);
78 let tx = std::mem::replace(tx, oneshot::channel().0);
79 let _ = tx.send(backup);
80 }
81 }
82 }
83}
84
85#[async_trait]
86impl<T> Reconnectable for Connection<T>
87where
88 T: Transport,
89{
90 async fn reconnect(&mut self) -> io::Result<()> {
105 async fn reconnect_client<T: Transport>(
106 id: ConnectionId,
107 reauth_otp: HeapSecretKey,
108 transport: &mut FramedTransport<T>,
109 ) -> io::Result<(ConnectionId, HeapSecretKey)> {
110 debug!("[Conn {id}] Re-establishing connection");
112 Reconnectable::reconnect(transport).await?;
113
114 debug!("[Conn {id}] Waiting for server version");
120 if transport.as_mut_inner().read_exact(&mut [0u8; 24]).await? != 24 {
121 return Err(io::Error::new(
122 io::ErrorKind::InvalidData,
123 "Wrong version byte len received",
124 ));
125 }
126
127 debug!("[Conn {id}] Performing handshake");
129 transport.client_handshake().await?;
130
131 debug!("[Conn {id}] Performing re-authentication");
133 transport
134 .write_frame_for(&ConnectType::Reconnect {
135 id,
136 otp: reauth_otp.unprotected_into_bytes(),
137 })
138 .await?;
139
140 debug!("[Conn {id}] Deriving future OTP for reauthentication");
142 let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();
143
144 Ok((id, reauth_otp))
145 }
146
147 match self {
148 Self::Client {
149 id,
150 transport,
151 reauth_otp,
152 } => {
153 let (new_id, new_reauth_otp) = {
156 transport.backup.freeze();
157 let result = reconnect_client(*id, reauth_otp.clone(), transport).await;
158 transport.backup.unfreeze();
159 result?
160 };
161
162 debug!("[Conn {id}] Synchronizing frame state");
164 transport.synchronize().await?;
165
166 info!("[Conn {id}] Reconnect completed successfully! Assigning new id {new_id}");
168 *id = new_id;
169 *reauth_otp = new_reauth_otp;
170
171 Ok(())
172 }
173
174 Self::Server { .. } => Err(io::Error::new(
175 io::ErrorKind::Unsupported,
176 "Server connection cannot reconnect",
177 )),
178 }
179 }
180}
181
/// First frame a client sends after the codec handshake. It tells the server
/// whether this is a brand-new connection or a reconnect of an existing one.
#[derive(Debug, Serialize, Deserialize)]
enum ConnectType {
    /// New connection; the server replies with the connection id it chose.
    Connect,

    /// Reconnect to an existing connection.
    Reconnect {
        /// Id of the connection being resumed.
        id: ConnectionId,

        /// One-time password derived at the end of the previous (re)connect,
        /// checked against the server's keychain entry for `id`.
        #[serde(with = "serde_bytes")]
        otp: Vec<u8>,
    },
}
199
impl<T> Connection<T>
where
    T: Transport,
{
    /// Turns a raw [`Transport`] into an established client [`Connection`].
    ///
    /// Protocol steps, in order:
    /// 1. read the server's 24-byte version and fail if it is not compatible with `version`;
    /// 2. perform the client side of the codec handshake;
    /// 3. send `ConnectType::Connect`;
    /// 4. receive the connection id chosen by the server and adopt it;
    /// 5. authenticate using `handler`;
    /// 6. exchange keys to derive the OTP used for future reconnects.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if the version read does not report 24 bytes, and
    /// `Other` if the versions are incompatible or no connection id frame arrives.
    /// Any error from the handshake, authentication, or key exchange is propagated.
    pub async fn client<H: AuthHandler + Send>(
        transport: T,
        handler: H,
        version: Version,
    ) -> io::Result<Self> {
        // Temporary id used only for logging until the server assigns one
        let id: ConnectionId = rand::random();

        debug!("[Conn {id}] Waiting for server version");
        let mut version_bytes = [0u8; 24];
        if transport.read_exact(&mut version_bytes).await? != 24 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Wrong version byte len received",
            ));
        }

        let server_version = Version::from_be_bytes(version_bytes);
        debug!(
            "[Conn {id}] Checking compatibility between client {version} & server {server_version}"
        );
        if !version.is_compatible_with(&server_version) {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!(
                    "Client version {version} is incompatible with server version {server_version}"
                ),
            ));
        }

        debug!("[Conn {id}] Performing handshake");
        let mut transport: FramedTransport<T> =
            FramedTransport::from_client_handshake(transport).await?;

        debug!("[Conn {id}] Communicating that this is a new connection");
        transport.write_frame_for(&ConnectType::Connect).await?;

        // Replace the temporary id with the one chosen by the server
        let id = {
            debug!("[Conn {id}] Receiving new connection id");
            let new_id = transport
                .read_frame_as::<ConnectionId>()
                .await?
                .ok_or_else(|| {
                    io::Error::new(io::ErrorKind::Other, "Missing connection id frame")
                })?;
            debug!("[Conn {id}] Resetting id to {new_id}");
            new_id
        };

        debug!("[Conn {id}] Performing authentication");
        transport.authenticate(handler).await?;

        debug!("[Conn {id}] Deriving future OTP for reauthentication");
        let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();

        info!("[Conn {id}] Connect completed successfully!");
        Ok(Self::Client {
            id,
            reauth_otp,
            transport,
        })
    }

    /// Turns a raw [`Transport`] into an established server [`Connection`].
    ///
    /// Sends `version`, performs the server side of the codec handshake, and then
    /// reads a `ConnectType` frame:
    ///
    /// * `Connect`: sends the generated id to the client, runs `verifier`, derives
    ///   a reauthentication OTP, and stores it in `keychain` under the id.
    /// * `Reconnect { id, otp }`: if `keychain` holds `id` with a matching OTP, the
    ///   server takes over that id and awaits the previous connection's backup. It
    ///   then derives a new OTP, restores the backup, and synchronizes frame state.
    ///
    /// # Errors
    ///
    /// Returns `PermissionDenied` when the reconnect id is unknown or its OTP does
    /// not match, and `Other` when the connection type frame is missing. Any error
    /// from the handshake, verification, key exchange, or synchronization is
    /// propagated. If a reconnect fails after its keychain entry was removed, the
    /// OTP the client presented is reinserted together with a receiver holding the
    /// backup.
    pub async fn server(
        transport: T,
        verifier: &Verifier,
        keychain: Keychain<oneshot::Receiver<Backup>>,
        version: Version,
    ) -> io::Result<Self> {
        let id: ConnectionId = rand::random();

        debug!("[Conn {id}] Sending version {version}");
        transport.write_all(&version.to_be_bytes()).await?;

        debug!("[Conn {id}] Performing handshake");
        let mut transport: FramedTransport<T> =
            FramedTransport::from_server_handshake(transport).await?;

        debug!("[Conn {id}] Waiting for connection type");
        let connection_type = transport
            .read_frame_as::<ConnectType>()
            .await?
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Missing connection type frame"))?;

        // `tx` goes into the returned connection and sends its backup on drop;
        // `rx` is stored in the keychain so a later reconnect can await that backup
        let (tx, rx) = oneshot::channel();

        let id = match connection_type {
            // New connection: tell the client its id, verify it, and record a
            // reauthentication OTP in the keychain
            ConnectType::Connect => {
                debug!("[Conn {id}] Telling other side to change connection id");
                transport.write_frame_for(&id).await?;

                debug!("[Conn {id}] Verifying connection");
                verifier.verify(&mut transport).await?;

                debug!("[Conn {id}] Deriving future OTP for reauthentication");
                let reauth_otp = transport.exchange_keys().await?.into_heap_secret_key();

                info!("[Conn {id}] Connect completed successfully!");
                keychain.insert(id.to_string(), reauth_otp, rx).await;

                id
            }
            // Returning connection: authenticated by the keychain OTP rather than the verifier
            ConnectType::Reconnect { id: other_id, otp } => {
                let reauth_otp = HeapSecretKey::from(otp);

                debug!("[Conn {id}] Checking if {other_id} exists and has matching OTP");
                match keychain
                    .remove_if_has_key(other_id.to_string(), reauth_otp.clone())
                    .await
                {
                    KeychainResult::Ok(x) => {
                        debug!("[Conn {id}] Reassigning to {other_id}");
                        let id = other_id;

                        // Resolves once the previous connection's sender fires (its
                        // Drop impl sends the backup); if the sender was dropped
                        // without sending, fall back to an empty backup
                        debug!("[Conn {id}] Acquiring backup for existing connection");
                        let backup = match x.await {
                            Ok(backup) => backup,
                            Err(_) => {
                                warn!("[Conn {id}] Missing backup, will use fresh copy");
                                Backup::new()
                            }
                        };

                        // On error: push the given backup into this connection's
                        // `tx`, reinsert the client's presented OTP with `rx` so the
                        // keychain entry is restored, and return the error
                        macro_rules! unwrap_or_fail {
                            ($action:expr) => {
                                unwrap_or_fail!(backup, $action)
                            };
                            ($backup:expr, $action:expr) => {{
                                match $action {
                                    Ok(x) => x,
                                    Err(x) => {
                                        error!("[Conn {id}] Encountered error, restoring with old backup");
                                        let _ = tx.send($backup);
                                        keychain.insert(id.to_string(), reauth_otp, rx).await;
                                        return Err(x);
                                    }
                                }
                            }};
                        }

                        debug!("[Conn {id}] Deriving future OTP for reauthentication");
                        let new_reauth_otp =
                            unwrap_or_fail!(transport.exchange_keys().await).into_heap_secret_key();

                        debug!("[Conn {id}] Restoring backup");
                        transport.backup = backup;

                        // On failure, whatever the transport's backup holds at this
                        // point is what gets handed back
                        debug!("[Conn {id}] Synchronizing frame state");
                        unwrap_or_fail!(transport.backup, transport.synchronize().await);

                        info!("[Conn {id}] Reconnect restoration completed successfully!");
                        keychain.insert(id.to_string(), new_reauth_otp, rx).await;

                        id
                    }
                    KeychainResult::InvalidPassword => {
                        return Err(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            "Invalid OTP for reconnect",
                        ));
                    }
                    KeychainResult::InvalidId => {
                        return Err(io::Error::new(
                            io::ErrorKind::PermissionDenied,
                            "Invalid id for reconnect",
                        ));
                    }
                }
            }
        };

        Ok(Self::Server { id, tx, transport })
    }
}
427
#[cfg(test)]
impl Connection<InmemoryTransport> {
    /// Builds a client/server pair over an in-memory framed transport without
    /// running any handshake. Both halves share one random id.
    ///
    /// The client gets an OTP from `HeapSecretKey::generate(32)`. The server's
    /// backup sender is paired with a receiver that is dropped immediately.
    pub fn pair(buffer: usize) -> (Self, Self) {
        let shared_id = rand::random::<ConnectionId>();
        let (client_side, server_side) = FramedTransport::pair(buffer);

        // Tuple elements evaluate left to right, so the client is built first.
        (
            Self::Client {
                id: shared_id,
                reauth_otp: HeapSecretKey::generate(32).unwrap(),
                transport: client_side,
            },
            Self::Server {
                id: shared_id,
                tx: oneshot::channel().0,
                transport: server_side,
            },
        )
    }
}
456
457impl<T> Connection<T> {
458 pub fn id(&self) -> ConnectionId {
460 match self {
461 Self::Client { id, .. } => *id,
462 Self::Server { id, .. } => *id,
463 }
464 }
465}
466
#[cfg(test)]
impl<T> Connection<T> {
    /// Returns the reauthentication OTP of a client connection, or `None` for a
    /// server connection (which does not store one).
    pub fn otp(&self) -> Option<&HeapSecretKey> {
        if let Self::Client { reauth_otp, .. } = self {
            Some(reauth_otp)
        } else {
            None
        }
    }

    /// Borrows the underlying framed transport.
    pub fn transport(&self) -> &FramedTransport<T> {
        // Deref coercion via the `Deref` impl, which yields the same field.
        self
    }

    /// Mutably borrows the underlying framed transport.
    pub fn mut_transport(&mut self) -> &mut FramedTransport<T> {
        // Deref coercion via the `DerefMut` impl, which yields the same field.
        self
    }
}
493
#[cfg(test)]
impl<T: Transport> Connection<T> {
    /// Wraps `transport` with `FramedTransport::plain` and returns a client
    /// connection with a random id and a freshly generated OTP, for tests.
    pub fn test_client(transport: T) -> Self {
        // Bound in the same order the original struct literal evaluated its fields
        let id: ConnectionId = rand::random();
        let reauth_otp = HeapSecretKey::generate(32).unwrap();
        let transport = FramedTransport::plain(transport);

        Self::Client {
            id,
            reauth_otp,
            transport,
        }
    }
}
504
#[cfg(test)]
mod tests {
    //! Tests for `Connection::client`, `Connection::server`, and `reconnect`.
    //! Most tests script one side of the protocol by hand over an in-memory
    //! framed transport; the `e2e` tests run a real client against a real server.

    use std::sync::Arc;

    use distant_auth::msg::Challenge;
    use distant_auth::{Authenticator, DummyAuthHandler};
    use test_log::test;

    use super::*;
    use crate::common::Frame;

    // Version advertised by test servers unless a test overrides it.
    macro_rules! server_version {
        () => {
            Version::new(1, 2, 3)
        };
    }

    // Writes a version's big-endian bytes straight to the raw inner transport,
    // mimicking the first step performed by `Connection::server`.
    macro_rules! send_server_version {
        ($transport:expr, $version:expr) => {{
            ($transport)
                .as_mut_inner()
                .write_all(&$version.to_be_bytes())
                .await
                .unwrap();
        }};
        ($transport:expr) => {
            send_server_version!($transport, server_version!());
        };
    }

    // Reads the 24 version bytes a server sends from the raw inner transport and
    // decodes them into a `Version`.
    macro_rules! receive_version {
        ($transport:expr) => {{
            let mut bytes = [0u8; 24];
            assert_eq!(
                ($transport)
                    .as_mut_inner()
                    .read_exact(&mut bytes)
                    .await
                    .unwrap(),
                24,
                "Wrong version len received"
            );
            Version::from_be_bytes(bytes)
        }};
    }

    #[test(tokio::test)]
    async fn client_should_fail_when_server_sends_incompatible_version() {
        let (mut t1, t2) = FramedTransport::pair(100);

        // Run the client on t2; t1 plays the server by hand. The unwrap inside the
        // task turns a client error into a panic, observed as a JoinError below.
        let task = tokio::spawn(async move {
            Connection::client(t2.into_inner(), DummyAuthHandler, Version::new(1, 2, 3))
                .await
                .unwrap()
        });

        // The test relies on 2.0.0 being incompatible with client 1.2.3
        send_server_version!(t1, Version::new(2, 0, 0));

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn client_should_fail_if_codec_handshake_fails() {
        let (mut t1, t2) = FramedTransport::pair(100);

        let task = tokio::spawn(async move {
            Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
                .await
                .unwrap()
        });

        send_server_version!(t1);

        // Send an arbitrary frame where the codec handshake is expected
        t1.write_frame(Frame::new(b"invalid")).await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn client_should_fail_if_unable_to_receive_connection_id_from_server() {
        let (mut t1, t2) = FramedTransport::pair(100);

        let task = tokio::spawn(async move {
            Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
                .await
                .unwrap()
        });

        send_server_version!(t1);

        t1.server_handshake().await.unwrap();

        let ct = t1.read_frame_as::<ConnectType>().await.unwrap().unwrap();
        assert!(
            matches!(ct, ConnectType::Connect),
            "Unexpected connect type: {ct:?}"
        );

        // Close the server side instead of sending a connection id
        drop(t1);

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn client_should_fail_if_authentication_fails() {
        let (mut t1, t2) = FramedTransport::pair(100);

        let task = tokio::spawn(async move {
            Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
                .await
                .unwrap()
        });

        send_server_version!(t1);

        t1.server_handshake().await.unwrap();

        let ct = t1.read_frame_as::<ConnectType>().await.unwrap().unwrap();
        assert!(
            matches!(ct, ConnectType::Connect),
            "Unexpected connect type: {ct:?}"
        );

        t1.write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        // Issue a challenge; the test expects both this side and the client
        // (using DummyAuthHandler) to end in an error
        t1.challenge(Challenge {
            questions: Vec::new(),
            options: Default::default(),
        })
        .await
        .unwrap_err();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn client_should_fail_if_unable_to_exchange_otp_for_reauthentication() {
        let (mut t1, t2) = FramedTransport::pair(100);

        let task = tokio::spawn(async move {
            Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
                .await
                .unwrap()
        });

        send_server_version!(t1);

        t1.server_handshake().await.unwrap();

        let ct = t1.read_frame_as::<ConnectType>().await.unwrap().unwrap();
        assert!(
            matches!(ct, ConnectType::Connect),
            "Unexpected connect type: {ct:?}"
        );

        t1.write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        Verifier::none().verify(&mut t1).await.unwrap();

        // Send an arbitrary frame where the key exchange is expected
        t1.write_frame(Frame::new(b"invalid")).await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn client_should_succeed_if_establishes_connection_with_server() {
        let (mut t1, t2) = FramedTransport::pair(100);

        let task = tokio::spawn(async move {
            Connection::client(t2.into_inner(), DummyAuthHandler, server_version!())
                .await
                .unwrap()
        });

        send_server_version!(t1);

        t1.server_handshake().await.unwrap();

        let ct = t1.read_frame_as::<ConnectType>().await.unwrap().unwrap();
        assert!(
            matches!(ct, ConnectType::Connect),
            "Unexpected connect type: {ct:?}"
        );

        t1.write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        Verifier::none().verify(&mut t1).await.unwrap();

        let otp = t1.exchange_keys().await.unwrap().into_heap_secret_key();

        // Both sides must end up with the same reauthentication OTP
        let client = task.await.unwrap();
        assert_eq!(client.otp(), Some(&otp));
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_client_drops_due_to_version() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        // Run the server on t2; t1 plays the client by hand
        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        // Simulate a client that disconnects after seeing the version
        drop(t1);

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_codec_handshake_fails() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        // Send an arbitrary frame where the codec handshake is expected
        t1.write_frame(Frame::new(b"invalid")).await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_unable_to_receive_connect_type() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        // Send a frame the server is expected to reject as a ConnectType
        t1.write_frame(Frame::new(b"hello")).await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_unable_to_verify_new_client() {
        let (mut t1, t2) = FramedTransport::pair(100);
        // Require a random static key; the client side below is expected to fail authentication
        let verifier = Verifier::static_key(HeapSecretKey::generate(32).unwrap());
        let keychain = Keychain::new();

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        t1.write_frame_for(&ConnectType::Connect).await.unwrap();

        let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();

        t1.authenticate(DummyAuthHandler).await.unwrap_err();

        drop(t1);

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_new_client() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        t1.write_frame_for(&ConnectType::Connect).await.unwrap();

        let _id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();

        t1.authenticate(DummyAuthHandler).await.unwrap();

        // Send an arbitrary frame where the key exchange is expected
        t1.write_frame(Frame::new(b"hello")).await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_existing_client_id_is_invalid() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        // The keychain is empty, so id 1234 is unknown to the server
        t1.write_frame_for(&ConnectType::Reconnect {
            id: 1234,
            otp: HeapSecretKey::generate(32)
                .unwrap()
                .unprotected_into_bytes(),
        })
        .await
        .unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_existing_client_otp_is_invalid() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        // Register id 1234 under a key the client below will not present
        keychain
            .insert(
                1234.to_string(),
                HeapSecretKey::generate(32).unwrap(),
                oneshot::channel().1,
            )
            .await;

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        t1.write_frame_for(&ConnectType::Reconnect {
            id: 1234,
            otp: HeapSecretKey::generate(32)
                .unwrap()
                .unprotected_into_bytes(),
        })
        .await
        .unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_unable_to_exchange_otp_for_reauthentication_with_existing_client(
    ) {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();
        let key = HeapSecretKey::generate(32).unwrap();

        // The sender half is dropped immediately, so the server falls back to a fresh backup
        keychain
            .insert(1234.to_string(), key.clone(), oneshot::channel().1)
            .await;

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        t1.write_frame_for(&ConnectType::Reconnect {
            id: 1234,
            otp: key.unprotected_into_bytes(),
        })
        .await
        .unwrap();

        // Send an arbitrary frame where the key exchange is expected
        t1.write_frame(Frame::new(b"hello")).await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_fail_if_unable_to_synchronize_with_existing_client() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();
        let key = HeapSecretKey::generate(32).unwrap();

        keychain
            .insert(1234.to_string(), key.clone(), oneshot::channel().1)
            .await;

        let task = tokio::spawn(async move {
            Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                .await
                .unwrap()
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        t1.write_frame_for(&ConnectType::Reconnect {
            id: 1234,
            otp: key.unprotected_into_bytes(),
        })
        .await
        .unwrap();

        let _otp = t1.exchange_keys().await.unwrap();

        // Send an arbitrary frame where synchronization is expected
        t1.write_frame(b"hello").await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn server_should_succeed_if_establishes_connection_with_new_client() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        let task = tokio::spawn({
            // Clone so this test can still inspect the keychain afterwards
            let keychain = keychain.clone();
            async move {
                Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                    .await
                    .unwrap()
            }
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        t1.write_frame_for(&ConnectType::Connect).await.unwrap();

        let id = t1.read_frame_as::<ConnectionId>().await.unwrap().unwrap();

        t1.authenticate(DummyAuthHandler).await.unwrap();

        let otp = t1.exchange_keys().await.unwrap();

        let server = task.await.unwrap();

        // The server keeps the id it sent to the client
        assert_eq!(server.id(), id);

        // The derived OTP must be recorded under that id for future reconnects
        assert!(
            keychain
                .has_key(id.to_string(), otp.into_heap_secret_key())
                .await,
            "Missing OTP"
        );
    }

    #[test(tokio::test)]
    async fn server_should_succeed_if_establishes_connection_with_existing_client() {
        let (mut t1, t2) = FramedTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();
        let key = HeapSecretKey::generate(32).unwrap();
        let id = 1234;

        // Register an existing connection whose backup holds two sent frames
        // ("hello", "world")
        keychain
            .insert(id.to_string(), key.clone(), {
                let mut backup = Backup::new();

                backup.push_frame(Frame::new(b"hello"));
                backup.push_frame(Frame::new(b"world"));
                backup.increment_sent_cnt();
                backup.increment_sent_cnt();

                let (tx, rx) = oneshot::channel();
                tx.send(backup).unwrap();
                rx
            })
            .await;

        let task = tokio::spawn({
            let keychain = keychain.clone();
            async move {
                Connection::server(t2.into_inner(), &verifier, keychain, server_version!())
                    .await
                    .unwrap()
            }
        });

        let _ = receive_version!(t1);

        t1.client_handshake().await.unwrap();

        t1.write_frame_for(&ConnectType::Reconnect {
            id: 1234,
            otp: key.unprotected_into_bytes(),
        })
        .await
        .unwrap();

        let otp = t1.exchange_keys().await.unwrap();

        // Give the client side its own backup with two sent frames ("foo", "bar")
        t1.backup.clear();
        t1.backup.push_frame(Frame::new(b"foo"));
        t1.backup.push_frame(Frame::new(b"bar"));
        t1.backup.increment_sent_cnt();
        t1.backup.increment_sent_cnt();

        t1.synchronize().await.unwrap();

        // After synchronizing, each side receives the frames held in the other's backup
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"hello");
        assert_eq!(t1.read_frame().await.unwrap().unwrap(), b"world");

        let mut server = task.await.unwrap();
        assert_eq!(server.read_frame().await.unwrap().unwrap(), b"foo");
        assert_eq!(server.read_frame().await.unwrap().unwrap(), b"bar");

        assert_eq!(server.id(), id);

        assert!(
            keychain
                .has_key(id.to_string(), otp.into_heap_secret_key())
                .await,
            "Missing OTP"
        );
    }

    #[test(tokio::test)]
    async fn client_server_new_connection_e2e_should_establish_connection() {
        let (t1, t2) = InmemoryTransport::pair(100);
        let verifier = Verifier::none();
        let keychain = Keychain::new();

        let task = tokio::spawn(async move {
            Connection::server(t2, &verifier, keychain, server_version!())
                .await
                .expect("Failed to connect from server")
        });

        let mut client = Connection::client(t1, DummyAuthHandler, server_version!())
            .await
            .expect("Failed to connect from client");
        let mut server = task.await.unwrap();

        // Frames flow in both directions over the established connection
        client.write_frame(Frame::new(b"hello")).await.unwrap();
        assert_eq!(server.read_frame().await.unwrap().unwrap(), b"hello");
        server.write_frame(Frame::new(b"goodbye")).await.unwrap();
        assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye");
    }

    /// Establishes a real client/server connection, then drops the server.
    ///
    /// Returns the client, a fresh transport linked to the client's inner
    /// transport, the verifier, and the keychain populated by the server.
    async fn setup_reconnect_scenario() -> (
        Connection<InmemoryTransport>,
        InmemoryTransport,
        Arc<Verifier>,
        Keychain<oneshot::Receiver<Backup>>,
    ) {
        let (t1, t2) = InmemoryTransport::pair(100);
        let verifier = Arc::new(Verifier::none());
        let keychain = Keychain::new();

        let task = {
            let verifier = Arc::clone(&verifier);
            let keychain = keychain.clone();
            tokio::spawn(async move {
                Connection::server(t2, &verifier, keychain, server_version!())
                    .await
                    .expect("Failed to connect from server")
            })
        };

        let mut client = Connection::client(t1, DummyAuthHandler, server_version!())
            .await
            .expect("Failed to connect from client");

        // Dropping the server connection sends its backup through the oneshot
        // whose receiver was stored in the keychain
        let server = task.await.unwrap();
        drop(server);

        // Link a standalone transport with the client's inner transport; the tests
        // use it as the server side of the reconnect
        let mut t2 = InmemoryTransport::pair(100).0;
        t2.link(client.mut_transport().as_mut_inner(), 100);

        (client, t2, verifier, keychain)
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_handshake_fails() {
        let (mut client, transport, _verifier, _keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        send_server_version!(transport);

        // Send an arbitrary frame where the codec handshake is expected
        transport.write_frame(b"hello").await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_unable_to_receive_new_connection_id() {
        let (mut client, transport, _verifier, _keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        send_server_version!(transport);

        transport.server_handshake().await.unwrap();

        // Hang up right after the handshake so the client's remaining reconnect steps fail
        drop(transport);

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_unable_to_exchange_otp_with_server() {
        let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        send_server_version!(transport);

        transport.server_handshake().await.unwrap();

        let (id, otp) = match transport.read_frame_as::<ConnectType>().await {
            Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)),
            x => panic!("Unexpected result: {x:?}"),
        };

        // The client must present the id/OTP the server recorded on first connect
        assert!(
            keychain.has_key(id.to_string(), otp).await,
            "Wrong id or OTP"
        );

        // Frames expected to make the client's key exchange fail
        transport
            .write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        transport.write_frame(Frame::new(b"hello")).await.unwrap();

        task.await.unwrap_err();
    }

    // NOTE(review): this body is identical to the OTP-exchange test above, so the
    // failure most likely happens during key exchange rather than synchronization
    // — confirm intent.
    #[test(tokio::test)]
    async fn reconnect_should_fail_if_client_side_connection_unable_to_synchronize_with_server() {
        let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        let task = tokio::spawn(async move { client.reconnect().await.unwrap() });

        send_server_version!(transport);

        transport.server_handshake().await.unwrap();

        let (id, otp) = match transport.read_frame_as::<ConnectType>().await {
            Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)),
            x => panic!("Unexpected result: {x:?}"),
        };

        assert!(
            keychain.has_key(id.to_string(), otp).await,
            "Wrong id or OTP"
        );

        transport
            .write_frame_for(&rand::random::<ConnectionId>())
            .await
            .unwrap();

        transport.write_frame(Frame::new(b"hello")).await.unwrap();

        task.await.unwrap_err();
    }

    #[test(tokio::test)]
    async fn reconnect_should_succeed_if_client_side_connection_fully_connects_and_synchronizes_with_server(
    ) {
        let (mut client, transport, _verifier, keychain) = setup_reconnect_scenario().await;
        let mut transport = FramedTransport::plain(transport);

        // Snapshot the client backup so its counters can be compared after reconnecting
        let client_backup = client.transport().backup.clone();

        let task = tokio::spawn(async move {
            client.reconnect().await.unwrap();
            client
        });

        send_server_version!(transport);

        transport.server_handshake().await.unwrap();

        let (id, otp) = match transport.read_frame_as::<ConnectType>().await {
            Ok(Some(ConnectType::Reconnect { id, otp })) => (id, HeapSecretKey::from(otp)),
            x => panic!("Unexpected result: {x:?}"),
        };

        // Act as the server: claim the keychain entry and recover the dropped
        // server's backup
        let backup = keychain
            .remove_if_has_key(id.to_string(), otp)
            .await
            .into_ok()
            .expect("Invalid id or OTP")
            .await
            .expect("Failed to retrieve backup");

        let otp = transport.exchange_keys().await.unwrap();

        transport.backup = backup;
        transport.synchronize().await.unwrap();

        let mut client = task.await.unwrap();
        assert_eq!(client.otp(), Some(&otp.into_heap_secret_key()));

        // The reconnect handshake runs with the backup frozen, so the counters
        // are expected to match the snapshot
        assert_eq!(
            client.transport().backup.sent_cnt(),
            client_backup.sent_cnt(),
            "Client backup sent cnt altered"
        );
        assert_eq!(
            client.transport().backup.received_cnt(),
            client_backup.received_cnt(),
            "Client backup received cnt altered"
        );

        client.write_frame(Frame::new(b"hello")).await.unwrap();
        assert_eq!(transport.read_frame().await.unwrap().unwrap(), b"hello");
        transport.write_frame(Frame::new(b"goodbye")).await.unwrap();
        assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye");
    }

    #[test(tokio::test)]
    async fn reconnect_should_fail_if_connection_is_server_side() {
        let mut connection = Connection::Server {
            id: rand::random(),
            tx: oneshot::channel().0,
            transport: FramedTransport::pair(100).0,
        };

        assert_eq!(
            connection.reconnect().await.unwrap_err().kind(),
            io::ErrorKind::Unsupported
        );
    }

    #[test(tokio::test)]
    async fn client_server_returning_connection_e2e_should_reestablish_connection() {
        let (mut client, transport, verifier, keychain) = setup_reconnect_scenario().await;

        // Run a real server on the linked transport to accept the reconnect
        let task = tokio::spawn(async move {
            Connection::server(transport, &verifier, keychain, server_version!())
                .await
                .expect("Failed to connect from server")
        });

        client
            .reconnect()
            .await
            .expect("Failed to reconnect from client");

        let mut server = task.await.unwrap();

        client.write_frame(Frame::new(b"hello")).await.unwrap();
        assert_eq!(server.read_frame().await.unwrap().unwrap(), b"hello");
        server.write_frame(Frame::new(b"goodbye")).await.unwrap();
        assert_eq!(client.read_frame().await.unwrap().unwrap(), b"goodbye");
    }
}