1use super::{
2 cipher,
3 handshake::{self, Confirmation},
4 nonce, x25519, Config, AUTHENTICATION_TAG_LENGTH,
5};
6use crate::{
7 utils::codec::{recv_frame, send_frame},
8 Error,
9};
10use bytes::Bytes;
11use chacha20poly1305::{aead::Aead, ChaCha20Poly1305};
12use commonware_codec::{DecodeExt, Encode};
13use commonware_cryptography::Signer;
14use commonware_macros::select;
15use commonware_runtime::{Clock, Sink, Spawner, Stream};
16use commonware_utils::{union, SystemTimeExt as _};
17use rand::{CryptoRng, Rng};
18use std::time::SystemTime;
19
20pub struct IncomingConnection<C: Signer, Si: Sink, St: Stream> {
22 config: Config<C>,
23 sink: Si,
24 stream: St,
25 deadline: SystemTime,
26 ephemeral_public_key: x25519::PublicKey,
27 peer_public_key: C::PublicKey,
28
29 dialer_hello_msg: Bytes,
32}
33
34impl<C: Signer, Si: Sink, St: Stream> IncomingConnection<C, Si, St> {
35 pub async fn verify<E: Clock + Spawner>(
36 context: &E,
37 config: Config<C>,
38 sink: Si,
39 mut stream: St,
40 ) -> Result<Self, Error> {
41 let deadline = context.current() + config.handshake_timeout;
43
44 let msg = select! {
46 _ = context.sleep_until(deadline) => { return Err(Error::HandshakeTimeout) },
47 result = recv_frame(&mut stream, config.max_message_size) => { result? },
48 };
49
50 let hello = handshake::Hello::decode(msg.as_ref()).map_err(Error::UnableToDecode)?;
52 hello.verify(
53 context,
54 &config.crypto.public_key(),
55 &config.namespace,
56 config.synchrony_bound,
57 config.max_handshake_age,
58 )?;
59 Ok(Self {
60 config,
61 sink,
62 stream,
63 deadline,
64 ephemeral_public_key: hello.ephemeral(),
65 peer_public_key: hello.signer(),
66 dialer_hello_msg: msg,
67 })
68 }
69
70 pub fn peer(&self) -> C::PublicKey {
72 self.peer_public_key.clone()
73 }
74
75 pub fn ephemeral(&self) -> x25519::PublicKey {
77 self.ephemeral_public_key
78 }
79}
80
81pub struct Connection<Si: Sink, St: Stream> {
83 sink: Si,
84 stream: St,
85
86 max_message_size: usize,
88
89 cipher_send: ChaCha20Poly1305,
91
92 cipher_recv: ChaCha20Poly1305,
94}
95
96impl<Si: Sink, St: Stream> Connection<Si, St> {
97 pub fn from_preestablished(
101 sink: Si,
102 stream: St,
103 max_message_size: usize,
104 cipher_send: ChaCha20Poly1305,
105 cipher_recv: ChaCha20Poly1305,
106 ) -> Self {
107 Self {
108 sink,
109 stream,
110 max_message_size,
111 cipher_send,
112 cipher_recv,
113 }
114 }
115
116 pub async fn upgrade_dialer<R: Rng + CryptoRng + Spawner + Clock, C: Signer>(
123 mut context: R,
124 mut config: Config<C>,
125 mut sink: Si,
126 mut stream: St,
127 peer: C::PublicKey,
128 ) -> Result<Self, Error> {
129 if peer == config.crypto.public_key() {
131 return Err(Error::DialSelf);
132 }
133
134 let deadline = context.current() + config.handshake_timeout;
136
137 let secret = x25519::new(&mut context);
139
140 let dialer_timestamp = context.current().epoch_millis();
142 let dialer_ephemeral = x25519::PublicKey::from_secret(&secret);
143 let hello_msg = handshake::Hello::sign(
144 &mut config.crypto,
145 &config.namespace,
146 handshake::Info::new(peer.clone(), dialer_ephemeral, dialer_timestamp),
147 )
148 .encode();
149
150 select! {
152 _ = context.sleep_until(deadline) => {
153 return Err(Error::HandshakeTimeout)
154 },
155 result = send_frame(&mut sink, &hello_msg, config.max_message_size) => {
156 result?;
157 },
158 }
159
160 let listener_response_msg = select! {
162 _ = context.sleep_until(deadline) => {
163 return Err(Error::HandshakeTimeout)
164 },
165 result = recv_frame(&mut stream, config.max_message_size) => {
166 result?
167 },
168 };
169
170 let (listener_hello, listener_confirmation) =
172 <(handshake::Hello<C::PublicKey>, Confirmation)>::decode(
173 listener_response_msg.as_ref(),
174 )
175 .map_err(Error::UnableToDecode)?;
176 listener_hello.verify(
177 &context,
178 &config.crypto.public_key(),
179 &config.namespace,
180 config.synchrony_bound,
181 config.max_handshake_age,
182 )?;
183
184 if peer != listener_hello.signer() {
186 return Err(Error::WrongPeer);
187 }
188
189 let shared_secret = secret.diffie_hellman(listener_hello.ephemeral().as_ref());
191 if !shared_secret.was_contributory() {
192 return Err(Error::SharedSecretNotContributory);
193 }
194
195 let hello_transcript = union(&hello_msg, &listener_hello.encode());
197 let cipher::Full {
198 confirmation,
199 traffic,
200 } = cipher::derive_directional(
201 shared_secret.as_bytes(),
202 &config.namespace,
203 &hello_transcript,
204 )?;
205
206 let cipher::Directional { d2l, l2d } = confirmation;
208 listener_confirmation.verify(l2d, &hello_transcript)?;
209
210 let full_transcript = union(&hello_msg, &listener_response_msg);
212 let confirmation_msg = Confirmation::create(d2l, &full_transcript)?.encode();
213 select! {
214 _ = context.sleep_until(deadline) => {
215 return Err(Error::HandshakeTimeout)
216 },
217 result = send_frame(
218 &mut sink,
219 &confirmation_msg,
220 config.max_message_size,
221 ) => {
222 result?;
223 },
224 }
225
226 Ok(Self {
228 sink,
229 stream,
230 max_message_size: config.max_message_size,
231 cipher_send: traffic.d2l,
232 cipher_recv: traffic.l2d,
233 })
234 }
235
236 pub async fn upgrade_listener<R: Rng + CryptoRng + Spawner + Clock, C: Signer>(
245 mut context: R,
246 incoming: IncomingConnection<C, Si, St>,
247 ) -> Result<Self, Error> {
248 let max_message_size = incoming.config.max_message_size;
250 let mut crypto = incoming.config.crypto;
251 let namespace = incoming.config.namespace;
252 let mut sink = incoming.sink;
253 let mut stream = incoming.stream;
254
255 let secret = x25519::new(&mut context);
257
258 let timestamp = context.current().epoch_millis();
260 let listener_ephemeral = x25519::PublicKey::from_secret(&secret);
261 let hello = handshake::Hello::sign(
262 &mut crypto,
263 &namespace,
264 handshake::Info::new(incoming.peer_public_key, listener_ephemeral, timestamp),
265 );
266
267 let shared_secret = secret.diffie_hellman(incoming.ephemeral_public_key.as_ref());
269 if !shared_secret.was_contributory() {
270 return Err(Error::SharedSecretNotContributory);
271 }
272
273 let hello_transcript = union(&incoming.dialer_hello_msg, &hello.encode());
275 let cipher::Full {
276 confirmation,
277 traffic,
278 } = cipher::derive_directional(shared_secret.as_bytes(), &namespace, &hello_transcript)?;
279
280 let cipher::Directional { l2d, d2l } = confirmation;
282 let confirmation = Confirmation::create(l2d, &hello_transcript)?;
283 let response_msg = (hello, confirmation).encode();
284 select! {
285 _ = context.sleep_until(incoming.deadline) => {
286 return Err(Error::HandshakeTimeout)
287 },
288 result = send_frame(&mut sink, &response_msg, max_message_size) => {
289 result?;
290 },
291 }
292
293 let confirmation_msg = select! {
295 _ = context.sleep_until(incoming.deadline) => {
296 return Err(Error::HandshakeTimeout)
297 },
298 result = recv_frame(&mut stream, max_message_size) => {
299 result?
300 },
301 };
302
303 let full_transcript = union(&incoming.dialer_hello_msg, &response_msg);
305 Confirmation::decode(confirmation_msg.as_ref())
306 .map_err(Error::UnableToDecode)?
307 .verify(d2l, &full_transcript)?;
308
309 Ok(Connection {
311 sink,
312 stream,
313 max_message_size,
314 cipher_send: traffic.l2d,
315 cipher_recv: traffic.d2l,
316 })
317 }
318
319 pub fn split(self) -> (Sender<Si>, Receiver<St>) {
324 (
325 Sender {
326 sink: self.sink,
327 max_message_size: self.max_message_size,
328 cipher: self.cipher_send,
329 nonce: nonce::Info::default(),
330 },
331 Receiver {
332 stream: self.stream,
333 max_message_size: self.max_message_size,
334 cipher: self.cipher_recv,
335 nonce: nonce::Info::default(),
336 },
337 )
338 }
339}
340
341pub struct Sender<Si: Sink> {
343 sink: Si,
344 max_message_size: usize,
345 cipher: ChaCha20Poly1305,
346 nonce: nonce::Info,
347}
348
349impl<Si: Sink> crate::Sender for Sender<Si> {
350 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
351 let nonce = self.nonce.next()?;
353 let msg = self
354 .cipher
355 .encrypt(&nonce, msg.as_ref())
356 .map_err(|_| Error::EncryptionFailed)?;
357
358 send_frame(
360 &mut self.sink,
361 &msg,
362 self.max_message_size + AUTHENTICATION_TAG_LENGTH,
363 )
364 .await?;
365 Ok(())
366 }
367}
368
369pub struct Receiver<St: Stream> {
371 stream: St,
372 max_message_size: usize,
373 cipher: ChaCha20Poly1305,
374 nonce: nonce::Info,
375}
376
377impl<St: Stream> crate::Receiver for Receiver<St> {
378 async fn receive(&mut self) -> Result<Bytes, Error> {
379 let msg = recv_frame(
381 &mut self.stream,
382 self.max_message_size + AUTHENTICATION_TAG_LENGTH,
383 )
384 .await?;
385
386 let nonce = self.nonce.next()?;
388 self.cipher
389 .decrypt(&nonce, msg.as_ref())
390 .map(Bytes::from)
391 .map_err(|_| Error::DecryptionFailed)
392 }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Receiver as _, Sender as _};
    use chacha20poly1305::KeyInit;
    use commonware_cryptography::{
        ed25519::{PrivateKey, PublicKey},
        PrivateKeyExt as _,
    };
    use commonware_runtime::{deterministic, mocks, Metrics, Runner};
    use std::time::Duration;

    /// Frame size limit used by every test configuration and mock peer.
    const MAX_MESSAGE_SIZE: usize = 1024;

    /// Builds a handshake `Config` with the limits shared by these tests.
    fn test_config(
        crypto: PrivateKey,
        namespace: &[u8],
        handshake_timeout: Duration,
    ) -> Config<PrivateKey> {
        Config {
            crypto,
            namespace: namespace.to_vec(),
            max_message_size: MAX_MESSAGE_SIZE,
            synchrony_bound: Duration::from_secs(5),
            max_handshake_age: Duration::from_secs(5),
            handshake_timeout,
        }
    }

    /// A ChaCha20-Poly1305 instance keyed with 32 copies of `byte`.
    fn fixed_cipher(byte: u8) -> ChaCha20Poly1305 {
        ChaCha20Poly1305::new(&[byte; 32].into())
    }

    #[test]
    fn test_decryption_failure() {
        deterministic::Runner::default().start(|_| async move {
            let (mut sink, stream) = mocks::Channel::init();
            let mut receiver = Receiver {
                stream,
                max_message_size: MAX_MESSAGE_SIZE,
                cipher: fixed_cipher(0),
                nonce: nonce::Info::default(),
            };
            let nonce_before = receiver.nonce;

            // Garbage is not a valid ciphertext...
            send_frame(&mut sink, b"invalid data", receiver.max_message_size)
                .await
                .unwrap();
            let outcome = receiver.receive().await;
            assert!(matches!(outcome, Err(Error::DecryptionFailed)));

            // ...but reading it still consumes a nonce.
            assert_ne!(nonce_before, receiver.nonce);
        });
    }

    #[test]
    fn test_send_too_large() {
        deterministic::Runner::default().start(|_| async move {
            let message = b"hello world";
            let (sink, _) = mocks::Channel::init();
            let mut sender = Sender {
                sink,
                max_message_size: message.len() - 1,
                cipher: fixed_cipher(0),
                nonce: nonce::Info::default(),
            };

            let outcome = sender.send(message).await;
            let ciphertext_len = message.len() + AUTHENTICATION_TAG_LENGTH;
            assert!(matches!(outcome, Err(Error::SendTooLarge(n)) if n == ciphertext_len));
        });
    }

    #[test]
    fn test_receive_too_large() {
        deterministic::Runner::default().start(|_| async move {
            let message = b"hello world";
            let (sink, stream) = mocks::Channel::init();

            let mut sender = Sender {
                sink,
                max_message_size: message.len(),
                cipher: fixed_cipher(0),
                nonce: nonce::Info::default(),
            };
            let mut receiver = Receiver {
                stream,
                max_message_size: message.len() - 1,
                cipher: fixed_cipher(0),
                nonce: nonce::Info::default(),
            };

            sender.send(message).await.unwrap();
            let outcome = receiver.receive().await;
            let ciphertext_len = message.len() + AUTHENTICATION_TAG_LENGTH;
            assert!(matches!(outcome, Err(Error::RecvTooLarge(n)) if n == ciphertext_len));
        });
    }

    #[test]
    fn test_send_receive() {
        deterministic::Runner::default().start(|_| async move {
            let cipher = fixed_cipher(0);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            let (mut dialer_sender, mut dialer_receiver) = Connection::from_preestablished(
                dialer_sink,
                dialer_stream,
                MAX_MESSAGE_SIZE,
                cipher.clone(),
                cipher.clone(),
            )
            .split();
            let (mut listener_sender, mut listener_receiver) = Connection::from_preestablished(
                listener_sink,
                listener_stream,
                MAX_MESSAGE_SIZE,
                cipher.clone(),
                cipher,
            )
            .split();

            let from_dialer = b"hello from dialer";
            dialer_sender.send(from_dialer).await.unwrap();
            let received = listener_receiver.receive().await.unwrap();
            assert_eq!(received, &from_dialer[..]);

            let from_listener = b"hello from listener";
            listener_sender.send(from_listener).await.unwrap();
            let received = dialer_receiver.receive().await.unwrap();
            assert_eq!(received, &from_listener[..]);

            for msg in [b"msg1", b"msg2", b"msg3"] {
                dialer_sender.send(msg).await.unwrap();
                let received = listener_receiver.receive().await.unwrap();
                assert_eq!(received, &msg[..]);
            }
            for msg in [b"reply1", b"reply2", b"reply3"] {
                listener_sender.send(msg).await.unwrap();
                let received = dialer_receiver.receive().await.unwrap();
                assert_eq!(received, &msg[..]);
            }
        });
    }

    #[test]
    fn test_full_connection_establishment_and_exchange() {
        deterministic::Runner::default().start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(0);
            let listener_crypto = PrivateKey::from_seed(1);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config =
                test_config(dialer_crypto, b"test_namespace", Duration::from_secs(5));
            let listener_config = test_config(
                listener_crypto.clone(),
                b"test_namespace",
                Duration::from_secs(5),
            );

            let listener_handle =
                context
                    .with_label("listener")
                    .spawn(move |context| async move {
                        let incoming = IncomingConnection::verify(
                            &context,
                            listener_config,
                            listener_sink,
                            listener_stream,
                        )
                        .await
                        .unwrap();
                        Connection::upgrade_listener(context, incoming)
                            .await
                            .unwrap()
                    });

            let dialer_connection = Connection::upgrade_dialer(
                context.clone(),
                dialer_config,
                dialer_sink,
                dialer_stream,
                listener_crypto.public_key(),
            )
            .await
            .unwrap();
            let listener_connection = listener_handle.await.unwrap();

            let (mut dialer_sender, mut dialer_receiver) = dialer_connection.split();
            let (mut listener_sender, mut listener_receiver) = listener_connection.split();

            // Two back-to-back messages in each direction exercise nonce advancement.
            let to_listener = b"Hello from dialer";
            dialer_sender.send(to_listener).await.unwrap();
            dialer_sender.send(to_listener).await.unwrap();
            for _ in 0..2 {
                let received = listener_receiver.receive().await.unwrap();
                assert_eq!(&received[..], &to_listener[..]);
            }

            let to_dialer = b"Hello from listener";
            listener_sender.send(to_dialer).await.unwrap();
            listener_sender.send(to_dialer).await.unwrap();
            for _ in 0..2 {
                let received = dialer_receiver.receive().await.unwrap();
                assert_eq!(&received[..], &to_dialer[..]);
            }
        });
    }

    #[test]
    fn test_upgrade_dialer_wrong_peer() {
        deterministic::Runner::default().start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(0);
            let expected_peer = PrivateKey::from_seed(1).public_key();
            let mut actual_peer = PrivateKey::from_seed(2);

            let (dialer_sink, mut peer_stream) = mocks::Channel::init();
            let (mut peer_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config =
                test_config(dialer_crypto, b"test_namespace", Duration::from_secs(5));

            // The impostor answers with a hello addressed to the dialer in the same namespace.
            let namespace = dialer_config.namespace.clone();
            let recipient = dialer_config.crypto.public_key();
            context
                .with_label("mock_peer")
                .spawn(move |mut context| async move {
                    let dialer_hello = recv_frame(&mut peer_stream, MAX_MESSAGE_SIZE)
                        .await
                        .unwrap();
                    let _ = handshake::Hello::<PublicKey>::decode(dialer_hello).unwrap();

                    let mock_cipher = fixed_cipher(1);
                    let ephemeral_secret = x25519::new(&mut context);
                    let sent_at = context.current().epoch_millis();
                    let info = handshake::Info::new(
                        recipient,
                        x25519::PublicKey::from_secret(&ephemeral_secret),
                        sent_at,
                    );
                    let hello = handshake::Hello::sign(&mut actual_peer, &namespace, info);
                    let confirmation =
                        Confirmation::create(mock_cipher, b"fake_transcript_data").unwrap();

                    send_frame(&mut peer_sink, &(hello, confirmation).encode(), MAX_MESSAGE_SIZE)
                        .await
                        .unwrap();
                });

            let outcome = Connection::upgrade_dialer(
                context,
                dialer_config,
                dialer_sink,
                dialer_stream,
                expected_peer,
            )
            .await;
            assert!(matches!(outcome, Err(Error::WrongPeer)));
        });
    }

    #[test]
    fn test_upgrade_dialer_non_contributory_secret() {
        deterministic::Runner::default().start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(0);
            let mut listener_crypto = PrivateKey::from_seed(1);
            let listener_public_key = listener_crypto.public_key();

            let (dialer_sink, mut peer_stream) = mocks::Channel::init();
            let (mut peer_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config =
                test_config(dialer_crypto, b"test_namespace", Duration::from_secs(5));

            let namespace = dialer_config.namespace.clone();
            let recipient = dialer_config.crypto.public_key();
            context
                .with_label("mock_peer")
                .spawn(move |context| async move {
                    let dialer_hello = recv_frame(&mut peer_stream, MAX_MESSAGE_SIZE)
                        .await
                        .unwrap();
                    let _ = handshake::Hello::<PublicKey>::decode(dialer_hello).unwrap();

                    let mock_cipher = fixed_cipher(1);
                    let sent_at = context.current().epoch_millis();
                    // An all-zero ephemeral key, which the dialer must reject.
                    let info = handshake::Info::new(
                        recipient,
                        x25519::PublicKey::from_bytes([0u8; 32]),
                        sent_at,
                    );
                    let hello = handshake::Hello::sign(&mut listener_crypto, &namespace, info);
                    let confirmation = Confirmation::create(
                        mock_cipher,
                        b"fake_transcript_for_non_contributory_test",
                    )
                    .unwrap();

                    send_frame(&mut peer_sink, &(hello, confirmation).encode(), MAX_MESSAGE_SIZE)
                        .await
                        .unwrap();
                });

            let outcome = Connection::upgrade_dialer(
                context,
                dialer_config,
                dialer_sink,
                dialer_stream,
                listener_public_key,
            )
            .await;
            assert!(matches!(outcome, Err(Error::SharedSecretNotContributory)));
        });
    }

    #[test]
    fn test_upgrade_listener_non_contributory_secret() {
        deterministic::Runner::default().start(|context| async move {
            let mut dialer_crypto = PrivateKey::from_seed(0);
            let listener_crypto = PrivateKey::from_seed(1);

            let (mut dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, _dialer_stream) = mocks::Channel::init();

            let listener_config =
                test_config(listener_crypto, b"test_namespace", Duration::from_secs(5));

            // The dialer advertises an all-zero ephemeral key.
            let info = handshake::Info::new(
                listener_config.crypto.public_key(),
                x25519::PublicKey::from_bytes([0u8; 32]),
                context.current().epoch_millis(),
            );
            let hello =
                handshake::Hello::sign(&mut dialer_crypto, &listener_config.namespace, info);
            send_frame(&mut dialer_sink, &hello.encode(), MAX_MESSAGE_SIZE)
                .await
                .unwrap();

            let incoming = IncomingConnection::verify(
                &context,
                listener_config,
                listener_sink,
                listener_stream,
            )
            .await
            .unwrap();
            let outcome = Connection::upgrade_listener(context, incoming).await;
            assert!(matches!(outcome, Err(Error::SharedSecretNotContributory)));
        });
    }

    #[test]
    fn test_listener_rejects_hello_signed_with_own_key() {
        deterministic::Runner::default().start(|context| async move {
            let self_crypto = PrivateKey::from_seed(0);
            let self_public_key = self_crypto.public_key();
            let config = test_config(
                self_crypto.clone(),
                b"test_self_connect_namespace",
                Duration::from_secs(1),
            );

            let (mut dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_reply_sink, _dialer_stream) = mocks::Channel::init();

            let listener_handle = context.with_label("self_listener").spawn({
                let listener_config = config.clone();
                move |task_ctx| async move {
                    IncomingConnection::verify(
                        &task_ctx,
                        listener_config,
                        listener_reply_sink,
                        listener_stream,
                    )
                    .await
                }
            });

            let frame_limit = config.max_message_size;
            let namespace = config.namespace.clone();
            let sender_handle = context.with_label("handshake_sender").spawn({
                let ephemeral_pk = x25519::PublicKey::from_bytes([0xCDu8; 32]);
                move |task_ctx| async move {
                    let mut signer = self_crypto;
                    let sent_at = task_ctx.current().epoch_millis();
                    let info = handshake::Info::new(self_public_key, ephemeral_pk, sent_at);
                    let hello = handshake::Hello::sign(&mut signer, &namespace, info);
                    send_frame(&mut dialer_sink, &hello.encode(), frame_limit).await
                }
            });

            sender_handle.await.unwrap().unwrap();

            let listener_result = listener_handle.await.unwrap();
            assert!(matches!(listener_result, Err(Error::HelloUsesOurKey)));
        });
    }

    #[test]
    fn test_three_message_handshake_protocol() {
        deterministic::Runner::default().start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(0);
            let listener_crypto = PrivateKey::from_seed(1);

            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config =
                test_config(dialer_crypto, b"test_3msg_namespace", Duration::from_secs(5));
            let listener_config = test_config(
                listener_crypto.clone(),
                b"test_3msg_namespace",
                Duration::from_secs(5),
            );

            let listener_handle =
                context
                    .with_label("listener")
                    .spawn(move |context| async move {
                        let incoming = IncomingConnection::verify(
                            &context,
                            listener_config,
                            listener_sink,
                            listener_stream,
                        )
                        .await
                        .unwrap();
                        Connection::upgrade_listener(context, incoming)
                            .await
                            .unwrap()
                    });

            let dialer_connection = Connection::upgrade_dialer(
                context.clone(),
                dialer_config,
                dialer_sink,
                dialer_stream,
                listener_crypto.public_key(),
            )
            .await
            .unwrap();
            let listener_connection = listener_handle.await.unwrap();

            let (mut dialer_sender, mut dialer_receiver) = dialer_connection.split();
            let (mut listener_sender, mut listener_receiver) = listener_connection.split();

            let to_listener = b"Hello from dialer after 3-msg handshake";
            dialer_sender.send(to_listener).await.unwrap();
            let received = listener_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &to_listener[..]);

            let to_dialer = b"Hello from listener after 3-msg handshake";
            listener_sender.send(to_dialer).await.unwrap();
            let received = dialer_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &to_dialer[..]);
        });
    }

    #[test]
    fn test_upgrade_dialer_rejects_connecting_to_self() {
        deterministic::Runner::default().start(|context| async move {
            let self_crypto = PrivateKey::from_seed(0);
            let self_public_key = self_crypto.public_key();

            let dialer_config =
                test_config(self_crypto, b"test_dial_self_direct", Duration::from_secs(1));

            let (dialer_sink, _unused_stream) = mocks::Channel::init();
            let (_unused_sink, dialer_stream) = mocks::Channel::init();

            let outcome = Connection::upgrade_dialer(
                context,
                dialer_config,
                dialer_sink,
                dialer_stream,
                self_public_key,
            )
            .await;
            assert!(matches!(outcome, Err(Error::DialSelf)));
        });
    }
}
1024}