commonware_stream/public_key/
connection.rs

1use super::{
2    cipher,
3    handshake::{self, Confirmation},
4    nonce, x25519, Config, AUTHENTICATION_TAG_LENGTH,
5};
6use crate::{
7    utils::codec::{recv_frame, send_frame},
8    Error,
9};
10use bytes::Bytes;
11use chacha20poly1305::{aead::Aead, ChaCha20Poly1305};
12use commonware_codec::{DecodeExt, Encode};
13use commonware_cryptography::Signer;
14use commonware_macros::select;
15use commonware_runtime::{Clock, Sink, Spawner, Stream};
16use commonware_utils::{union, SystemTimeExt as _};
17use rand::{CryptoRng, Rng};
18use std::time::SystemTime;
19
20/// An incoming connection with a verified peer handshake.
21pub struct IncomingConnection<C: Signer, Si: Sink, St: Stream> {
22    config: Config<C>,
23    sink: Si,
24    stream: St,
25    deadline: SystemTime,
26    ephemeral_public_key: x25519::PublicKey,
27    peer_public_key: C::PublicKey,
28
29    /// Stores the raw bytes of the dialer hello message.
30    /// Necessary for the cipher derivation.
31    dialer_hello_msg: Bytes,
32}
33
34impl<C: Signer, Si: Sink, St: Stream> IncomingConnection<C, Si, St> {
35    pub async fn verify<E: Clock + Spawner>(
36        context: &E,
37        config: Config<C>,
38        sink: Si,
39        mut stream: St,
40    ) -> Result<Self, Error> {
41        // Set handshake deadline
42        let deadline = context.current() + config.handshake_timeout;
43
44        // Wait for up to handshake timeout for response (Message 1)
45        let msg = select! {
46            _ = context.sleep_until(deadline) => { return Err(Error::HandshakeTimeout) },
47            result = recv_frame(&mut stream, config.max_message_size) => { result? },
48        };
49
50        // Verify hello message from peer
51        let hello = handshake::Hello::decode(msg.as_ref()).map_err(Error::UnableToDecode)?;
52        hello.verify(
53            context,
54            &config.crypto.public_key(),
55            &config.namespace,
56            config.synchrony_bound,
57            config.max_handshake_age,
58        )?;
59        Ok(Self {
60            config,
61            sink,
62            stream,
63            deadline,
64            ephemeral_public_key: hello.ephemeral(),
65            peer_public_key: hello.signer(),
66            dialer_hello_msg: msg,
67        })
68    }
69
70    /// The public key of the peer attempting to connect.
71    pub fn peer(&self) -> C::PublicKey {
72        self.peer_public_key.clone()
73    }
74
75    /// The ephemeral public key of the peer attempting to connect.
76    pub fn ephemeral(&self) -> x25519::PublicKey {
77        self.ephemeral_public_key
78    }
79}
80
81/// A fully initialized connection with some peer.
82pub struct Connection<Si: Sink, St: Stream> {
83    sink: Si,
84    stream: St,
85
86    /// The maximum size of a message that can be sent or received.
87    max_message_size: usize,
88
89    /// The cipher used for sending messages.
90    cipher_send: ChaCha20Poly1305,
91
92    /// The cipher used for receiving messages.
93    cipher_recv: ChaCha20Poly1305,
94}
95
96impl<Si: Sink, St: Stream> Connection<Si, St> {
97    /// Create a new connection from pre-established components.
98    ///
99    /// This is useful in tests, or when upgrading a connection that has already been verified.
100    pub fn from_preestablished(
101        sink: Si,
102        stream: St,
103        max_message_size: usize,
104        cipher_send: ChaCha20Poly1305,
105        cipher_recv: ChaCha20Poly1305,
106    ) -> Self {
107        Self {
108            sink,
109            stream,
110            max_message_size,
111            cipher_send,
112            cipher_recv,
113        }
114    }
115
116    /// Attempt to upgrade a raw connection we initiated as the dialer.
117    ///
118    /// This implements the 3-message handshake protocol where the dialer:
119    /// 1. Sends initial `hello` message to the listener
120    /// 2. Receives listener response with `hello + confirmation`
121    /// 3. Sends `confirmation` to the listener
122    pub async fn upgrade_dialer<R: Rng + CryptoRng + Spawner + Clock, C: Signer>(
123        mut context: R,
124        mut config: Config<C>,
125        mut sink: Si,
126        mut stream: St,
127        peer: C::PublicKey,
128    ) -> Result<Self, Error> {
129        // Ensure we are not trying to connect to ourselves
130        if peer == config.crypto.public_key() {
131            return Err(Error::DialSelf);
132        }
133
134        // Set handshake deadline
135        let deadline = context.current() + config.handshake_timeout;
136
137        // Generate shared secret
138        let secret = x25519::new(&mut context);
139
140        // Send hello (Message 1)
141        let dialer_timestamp = context.current().epoch_millis();
142        let dialer_ephemeral = x25519::PublicKey::from_secret(&secret);
143        let hello_msg = handshake::Hello::sign(
144            &mut config.crypto,
145            &config.namespace,
146            handshake::Info::new(peer.clone(), dialer_ephemeral, dialer_timestamp),
147        )
148        .encode();
149
150        // Wait for up to handshake timeout to send
151        select! {
152            _ = context.sleep_until(deadline) => {
153                return Err(Error::HandshakeTimeout)
154            },
155            result = send_frame(&mut sink, &hello_msg, config.max_message_size) => {
156                result?;
157            },
158        }
159
160        // Wait for listener's hello + confirmation (Message 2)
161        let listener_response_msg = select! {
162            _ = context.sleep_until(deadline) => {
163                return Err(Error::HandshakeTimeout)
164            },
165            result = recv_frame(&mut stream, config.max_message_size) => {
166                result?
167            },
168        };
169
170        // Verify listener's hello
171        let (listener_hello, listener_confirmation) =
172            <(handshake::Hello<C::PublicKey>, Confirmation)>::decode(
173                listener_response_msg.as_ref(),
174            )
175            .map_err(Error::UnableToDecode)?;
176        listener_hello.verify(
177            &context,
178            &config.crypto.public_key(),
179            &config.namespace,
180            config.synchrony_bound,
181            config.max_handshake_age,
182        )?;
183
184        // Ensure we connected to the right peer
185        if peer != listener_hello.signer() {
186            return Err(Error::WrongPeer);
187        }
188
189        // Derive shared secret and ensure it is contributory
190        let shared_secret = secret.diffie_hellman(listener_hello.ephemeral().as_ref());
191        if !shared_secret.was_contributory() {
192            return Err(Error::SharedSecretNotContributory);
193        }
194
195        // Create ciphers
196        let hello_transcript = union(&hello_msg, &listener_hello.encode());
197        let cipher::Full {
198            confirmation,
199            traffic,
200        } = cipher::derive_directional(
201            shared_secret.as_bytes(),
202            &config.namespace,
203            &hello_transcript,
204        )?;
205
206        // Verify listener's confirmation
207        let cipher::Directional { d2l, l2d } = confirmation;
208        listener_confirmation.verify(l2d, &hello_transcript)?;
209
210        // Create our own confirmation (Message 3)
211        let full_transcript = union(&hello_msg, &listener_response_msg);
212        let confirmation_msg = Confirmation::create(d2l, &full_transcript)?.encode();
213        select! {
214            _ = context.sleep_until(deadline) => {
215                return Err(Error::HandshakeTimeout)
216            },
217            result = send_frame(
218                &mut sink,
219                &confirmation_msg,
220                config.max_message_size,
221            ) => {
222                result?;
223            },
224        }
225
226        // Connection successfully established
227        Ok(Self {
228            sink,
229            stream,
230            max_message_size: config.max_message_size,
231            cipher_send: traffic.d2l,
232            cipher_recv: traffic.l2d,
233        })
234    }
235
236    /// Attempt to upgrade a connection we received as the listener.
237    ///
238    /// This implements the last two steps of the 3-message handshake protocol. The first step,
239    /// where the listener receives the dialer's `hello`, is handled by [IncomingConnection::verify].
240    ///
241    /// The last two steps are:
242    /// 2. Sends a response with `hello + confirmation`
243    /// 3. Receives the dialer's `confirmation`
244    pub async fn upgrade_listener<R: Rng + CryptoRng + Spawner + Clock, C: Signer>(
245        mut context: R,
246        incoming: IncomingConnection<C, Si, St>,
247    ) -> Result<Self, Error> {
248        // Extract fields
249        let max_message_size = incoming.config.max_message_size;
250        let mut crypto = incoming.config.crypto;
251        let namespace = incoming.config.namespace;
252        let mut sink = incoming.sink;
253        let mut stream = incoming.stream;
254
255        // Generate personal secret
256        let secret = x25519::new(&mut context);
257
258        // Create hello
259        let timestamp = context.current().epoch_millis();
260        let listener_ephemeral = x25519::PublicKey::from_secret(&secret);
261        let hello = handshake::Hello::sign(
262            &mut crypto,
263            &namespace,
264            handshake::Info::new(incoming.peer_public_key, listener_ephemeral, timestamp),
265        );
266
267        // Derive shared secret and ensure it is contributory
268        let shared_secret = secret.diffie_hellman(incoming.ephemeral_public_key.as_ref());
269        if !shared_secret.was_contributory() {
270            return Err(Error::SharedSecretNotContributory);
271        }
272
273        // Create ciphers
274        let hello_transcript = union(&incoming.dialer_hello_msg, &hello.encode());
275        let cipher::Full {
276            confirmation,
277            traffic,
278        } = cipher::derive_directional(shared_secret.as_bytes(), &namespace, &hello_transcript)?;
279
280        // Create and send hello + confirmation (Message 2)
281        let cipher::Directional { l2d, d2l } = confirmation;
282        let confirmation = Confirmation::create(l2d, &hello_transcript)?;
283        let response_msg = (hello, confirmation).encode();
284        select! {
285            _ = context.sleep_until(incoming.deadline) => {
286                return Err(Error::HandshakeTimeout)
287            },
288            result = send_frame(&mut sink, &response_msg, max_message_size) => {
289                result?;
290            },
291        }
292
293        // Wait for dialer confirmation (Message 3)
294        let confirmation_msg = select! {
295            _ = context.sleep_until(incoming.deadline) => {
296                return Err(Error::HandshakeTimeout)
297            },
298            result = recv_frame(&mut stream, max_message_size) => {
299                result?
300            },
301        };
302
303        // Verify dialer's confirmation
304        let full_transcript = union(&incoming.dialer_hello_msg, &response_msg);
305        Confirmation::decode(confirmation_msg.as_ref())
306            .map_err(Error::UnableToDecode)?
307            .verify(d2l, &full_transcript)?;
308
309        // Connection successfully established
310        Ok(Connection {
311            sink,
312            stream,
313            max_message_size,
314            cipher_send: traffic.l2d,
315            cipher_recv: traffic.d2l,
316        })
317    }
318
319    /// Split the connection into a `Sender` and `Receiver`.
320    ///
321    /// This pattern is commonly used to efficiently send and receive messages
322    /// over the same connection concurrently.
323    pub fn split(self) -> (Sender<Si>, Receiver<St>) {
324        (
325            Sender {
326                sink: self.sink,
327                max_message_size: self.max_message_size,
328                cipher: self.cipher_send,
329                nonce: nonce::Info::default(),
330            },
331            Receiver {
332                stream: self.stream,
333                max_message_size: self.max_message_size,
334                cipher: self.cipher_recv,
335                nonce: nonce::Info::default(),
336            },
337        )
338    }
339}
340
341/// The half of the `Connection` that implements `crate::Sender`.
342pub struct Sender<Si: Sink> {
343    sink: Si,
344    max_message_size: usize,
345    cipher: ChaCha20Poly1305,
346    nonce: nonce::Info,
347}
348
349impl<Si: Sink> crate::Sender for Sender<Si> {
350    async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
351        // Encrypt data
352        let nonce = self.nonce.next()?;
353        let msg = self
354            .cipher
355            .encrypt(&nonce, msg.as_ref())
356            .map_err(|_| Error::EncryptionFailed)?;
357
358        // Send data
359        send_frame(
360            &mut self.sink,
361            &msg,
362            self.max_message_size + AUTHENTICATION_TAG_LENGTH,
363        )
364        .await?;
365        Ok(())
366    }
367}
368
369/// The half of a `Connection` that implements `crate::Receiver`.
370pub struct Receiver<St: Stream> {
371    stream: St,
372    max_message_size: usize,
373    cipher: ChaCha20Poly1305,
374    nonce: nonce::Info,
375}
376
377impl<St: Stream> crate::Receiver for Receiver<St> {
378    async fn receive(&mut self) -> Result<Bytes, Error> {
379        // Read data
380        let msg = recv_frame(
381            &mut self.stream,
382            self.max_message_size + AUTHENTICATION_TAG_LENGTH,
383        )
384        .await?;
385
386        // Decrypt data
387        let nonce = self.nonce.next()?;
388        self.cipher
389            .decrypt(&nonce, msg.as_ref())
390            .map(Bytes::from)
391            .map_err(|_| Error::DecryptionFailed)
392    }
393}
394
395#[cfg(test)]
396mod tests {
397    use super::*;
398    use crate::{Receiver as _, Sender as _};
399    use chacha20poly1305::KeyInit;
400    use commonware_cryptography::{
401        ed25519::{PrivateKey, PublicKey},
402        PrivateKeyExt as _,
403    };
404    use commonware_runtime::{deterministic, mocks, Metrics, Runner};
405    use std::time::Duration;
406
407    #[test]
408    fn test_decryption_failure() {
409        let executor = deterministic::Runner::default();
410        executor.start(|_| async move {
411            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
412            let (mut sink, stream) = mocks::Channel::init();
413            let mut receiver = Receiver {
414                cipher,
415                stream,
416                max_message_size: 1024,
417                nonce: nonce::Info::default(),
418            };
419
420            // Store initial nonce value
421            let initial_nonce = receiver.nonce;
422
423            // Send invalid ciphertext
424            send_frame(&mut sink, b"invalid data", receiver.max_message_size)
425                .await
426                .unwrap();
427
428            // Attempt to receive (should fail)
429            let result = receiver.receive().await;
430            assert!(matches!(result, Err(Error::DecryptionFailed)));
431
432            // Verify nonce was incremented despite decryption failure
433            let final_nonce = receiver.nonce;
434            assert_ne!(initial_nonce, final_nonce);
435        });
436    }
437
438    #[test]
439    fn test_send_too_large() {
440        let executor = deterministic::Runner::default();
441        executor.start(|_| async move {
442            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
443            let message = b"hello world";
444            let (sink, _) = mocks::Channel::init();
445            let mut sender = Sender {
446                cipher,
447                sink,
448                max_message_size: message.len() - 1,
449                nonce: nonce::Info::default(),
450            };
451
452            let result = sender.send(message).await;
453            let expected_length = message.len() + AUTHENTICATION_TAG_LENGTH;
454            assert!(matches!(result, Err(Error::SendTooLarge(n)) if n == expected_length));
455        });
456    }
457
458    #[test]
459    fn test_receive_too_large() {
460        let executor = deterministic::Runner::default();
461        executor.start(|_| async move {
462            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
463            let message = b"hello world";
464            let (sink, stream) = mocks::Channel::init();
465
466            let mut sender = Sender {
467                cipher: cipher.clone(),
468                sink,
469                max_message_size: message.len(),
470                nonce: nonce::Info::default(),
471            };
472            let mut receiver = Receiver {
473                cipher,
474                stream,
475                max_message_size: message.len() - 1,
476                nonce: nonce::Info::default(),
477            };
478
479            sender.send(message).await.unwrap();
480            let result = receiver.receive().await;
481            let expected_length = message.len() + AUTHENTICATION_TAG_LENGTH;
482            assert!(matches!(result, Err(Error::RecvTooLarge(n)) if n == expected_length));
483        });
484    }
485
486    #[test]
487    fn test_send_receive() {
488        let executor = deterministic::Runner::default();
489        executor.start(|_| async move {
490            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
491            let max_message_size = 1024;
492
493            // Create channels
494            let (dialer_sink, listener_stream) = mocks::Channel::init();
495            let (listener_sink, dialer_stream) = mocks::Channel::init();
496
497            // Create dialer connection
498            let connection_dialer = Connection::from_preestablished(
499                dialer_sink,
500                dialer_stream,
501                max_message_size,
502                cipher.clone(),
503                cipher.clone(),
504            );
505
506            // Create listener connection
507            let connection_listener = Connection::from_preestablished(
508                listener_sink,
509                listener_stream,
510                max_message_size,
511                cipher.clone(),
512                cipher,
513            );
514
515            // Split into sender and receiver for both connections
516            let (mut dialer_sender, mut dialer_receiver) = connection_dialer.split();
517            let (mut listener_sender, mut listener_receiver) = connection_listener.split();
518
519            // Test 1: Send from dialer to listener
520            let msg1 = b"hello from dialer";
521            dialer_sender.send(msg1).await.unwrap();
522            let received1 = listener_receiver.receive().await.unwrap();
523            assert_eq!(received1, &msg1[..]);
524
525            // Test 2: Send from listener to dialer
526            let msg2 = b"hello from listener";
527            listener_sender.send(msg2).await.unwrap();
528            let received2 = dialer_receiver.receive().await.unwrap();
529            assert_eq!(received2, &msg2[..]);
530
531            // Test 3: Send multiple messages both ways
532            let messages_to_listener = vec![b"msg1", b"msg2", b"msg3"];
533            for msg in &messages_to_listener {
534                dialer_sender.send(*msg).await.unwrap();
535                let received = listener_receiver.receive().await.unwrap();
536                assert_eq!(received, &msg[..]);
537            }
538            let messages_to_dialer = vec![b"reply1", b"reply2", b"reply3"];
539            for msg in &messages_to_dialer {
540                listener_sender.send(*msg).await.unwrap();
541                let received = dialer_receiver.receive().await.unwrap();
542                assert_eq!(received, &msg[..]);
543            }
544        });
545    }
546    #[test]
547    fn test_full_connection_establishment_and_exchange() {
548        let executor = deterministic::Runner::default();
549        executor.start(|context| async move {
550            // Create cryptographic identities
551            let dialer_crypto = PrivateKey::from_seed(0);
552            let listener_crypto = PrivateKey::from_seed(1);
553
554            // Set up mock channels for transport simulation
555            let (dialer_sink, listener_stream) = mocks::Channel::init();
556            let (listener_sink, dialer_stream) = mocks::Channel::init();
557
558            // Configuration for dialer
559            let dialer_config = Config {
560                crypto: dialer_crypto.clone(),
561                namespace: b"test_namespace".to_vec(),
562                max_message_size: 1024,
563                synchrony_bound: Duration::from_secs(5),
564                max_handshake_age: Duration::from_secs(5),
565                handshake_timeout: Duration::from_secs(5),
566            };
567
568            // Configuration for listener
569            let listener_config = Config {
570                crypto: listener_crypto.clone(),
571                namespace: b"test_namespace".to_vec(),
572                max_message_size: 1024,
573                synchrony_bound: Duration::from_secs(5),
574                max_handshake_age: Duration::from_secs(5),
575                handshake_timeout: Duration::from_secs(5),
576            };
577
578            // Spawn listener to handle incoming connection
579            let listener_handle = context.with_label("listener").spawn({
580                move |context| async move {
581                    let incoming = IncomingConnection::verify(
582                        &context,
583                        listener_config,
584                        listener_sink,
585                        listener_stream,
586                    )
587                    .await
588                    .unwrap();
589                    Connection::upgrade_listener(context, incoming)
590                        .await
591                        .unwrap()
592                }
593            });
594
595            // Dialer initiates the connection
596            let dialer_connection = Connection::upgrade_dialer(
597                context.clone(),
598                dialer_config,
599                dialer_sink,
600                dialer_stream,
601                listener_crypto.public_key(),
602            )
603            .await
604            .unwrap();
605
606            // Wait for listener connection to be established
607            let listener_connection = listener_handle.await.unwrap();
608
609            // Split connections into sender and receiver halves
610            let (mut dialer_sender, mut dialer_receiver) = dialer_connection.split();
611            let (mut listener_sender, mut listener_receiver) = listener_connection.split();
612
613            // Dialer sends to listener twice
614            let message1 = b"Hello from dialer";
615            dialer_sender.send(message1).await.unwrap();
616            dialer_sender.send(message1).await.unwrap();
617            let received = listener_receiver.receive().await.unwrap();
618            assert_eq!(&received[..], &message1[..]);
619            let received = listener_receiver.receive().await.unwrap();
620            assert_eq!(&received[..], &message1[..]);
621
622            // Listener sends to dialer twice
623            let message2 = b"Hello from listener";
624            listener_sender.send(message2).await.unwrap();
625            listener_sender.send(message2).await.unwrap();
626            let received = dialer_receiver.receive().await.unwrap();
627            assert_eq!(&received[..], &message2[..]);
628            let received = dialer_receiver.receive().await.unwrap();
629            assert_eq!(&received[..], &message2[..]);
630        });
631    }
632
633    #[test]
634    fn test_upgrade_dialer_wrong_peer() {
635        let executor = deterministic::Runner::default();
636        executor.start(|context| async move {
637            // Create cryptographic identities
638            let dialer_crypto = PrivateKey::from_seed(0);
639            let expected_peer = PrivateKey::from_seed(1).public_key();
640            let mut actual_peer = PrivateKey::from_seed(2);
641
642            // Set up mock channels
643            let (dialer_sink, mut peer_stream) = mocks::Channel::init();
644            let (mut peer_sink, dialer_stream) = mocks::Channel::init();
645
646            // Dialer configuration
647            let dialer_config = Config {
648                crypto: dialer_crypto,
649                namespace: b"test_namespace".to_vec(),
650                max_message_size: 1024,
651                synchrony_bound: Duration::from_secs(5),
652                max_handshake_age: Duration::from_secs(5),
653                handshake_timeout: Duration::from_secs(5),
654            };
655            let peer_config = dialer_config.clone();
656
657            // Spawn a mock peer that responds with a listener response from wrong peer
658            context.with_label("mock_peer").spawn({
659                move |mut context| async move {
660                    use chacha20poly1305::KeyInit;
661
662                    // Read the hello from dialer
663                    let msg = recv_frame(&mut peer_stream, 1024).await.unwrap();
664                    let _ = handshake::Hello::<PublicKey>::decode(msg).unwrap();
665
666                    // Create mock shared secret and cipher for `confirmation`
667                    let mock_secret = [1u8; 32];
668                    let mock_cipher = ChaCha20Poly1305::new(&mock_secret.into());
669
670                    // Create and send own hello as listener response
671                    let secret = x25519::new(&mut context);
672                    let timestamp = context.current().epoch_millis();
673                    let info = handshake::Info::new(
674                        peer_config.crypto.public_key(),
675                        x25519::PublicKey::from_secret(&secret),
676                        timestamp,
677                    );
678                    let hello =
679                        handshake::Hello::sign(&mut actual_peer, &peer_config.namespace, info);
680
681                    // Create fake `confirmation` (using fake transcript)
682                    let fake_transcript = b"fake_transcript_data";
683                    let confirmation = Confirmation::create(mock_cipher, fake_transcript).unwrap();
684
685                    send_frame(&mut peer_sink, &(hello, confirmation).encode(), 1024)
686                        .await
687                        .unwrap();
688                }
689            });
690
691            // Attempt connection with expected peer key
692            let result = Connection::upgrade_dialer(
693                context,
694                dialer_config,
695                dialer_sink,
696                dialer_stream,
697                expected_peer,
698            )
699            .await;
700
701            // Verify the error
702            assert!(matches!(result, Err(Error::WrongPeer)));
703        });
704    }
705
706    #[test]
707    fn test_upgrade_dialer_non_contributory_secret() {
708        let executor = deterministic::Runner::default();
709        executor.start(|context| async move {
710            // Create cryptographic identities
711            let dialer_crypto = PrivateKey::from_seed(0);
712            let mut listener_crypto = PrivateKey::from_seed(1);
713            let listener_public_key = listener_crypto.public_key();
714
715            // Set up mock channels
716            let (dialer_sink, mut peer_stream) = mocks::Channel::init();
717            let (mut peer_sink, dialer_stream) = mocks::Channel::init();
718
719            // Dialer configuration
720            let dialer_config = Config {
721                crypto: dialer_crypto,
722                namespace: b"test_namespace".to_vec(),
723                max_message_size: 1024,
724                synchrony_bound: Duration::from_secs(5),
725                max_handshake_age: Duration::from_secs(5),
726                handshake_timeout: Duration::from_secs(5),
727            };
728
729            // Spawn a mock peer that responds with a listener response containing an all-zero ephemeral key
730            context.with_label("mock_peer").spawn({
731                let namespace = dialer_config.namespace.clone();
732                let recipient_pk = dialer_config.crypto.public_key();
733                move |context| async move {
734                    use chacha20poly1305::KeyInit;
735
736                    // Read the hello from dialer
737                    let msg = recv_frame(&mut peer_stream, 1024).await.unwrap();
738                    let _ = handshake::Hello::<PublicKey>::decode(msg).unwrap();
739
740                    // Create mock cipher for `confirmation`
741                    let mock_secret = [1u8; 32];
742                    let mock_cipher = ChaCha20Poly1305::new(&mock_secret.into());
743
744                    // Create a custom hello info bytes with zero ephemeral key
745                    let timestamp = context.current().epoch_millis();
746                    let info = handshake::Info::new(
747                        recipient_pk,
748                        x25519::PublicKey::from_bytes([0u8; 32]),
749                        timestamp,
750                    );
751
752                    // Create the signed `hello`
753                    let hello = handshake::Hello::sign(&mut listener_crypto, &namespace, info);
754
755                    // Create fake listener response (using fake transcript)
756                    let fake_transcript = b"fake_transcript_for_non_contributory_test";
757                    let confirmation = Confirmation::create(mock_cipher, fake_transcript).unwrap();
758
759                    // Send the listener response
760                    send_frame(&mut peer_sink, &(hello, confirmation).encode(), 1024)
761                        .await
762                        .unwrap();
763                }
764            });
765
766            // Attempt connection - should fail due to non-contributory shared secret
767            let result = Connection::upgrade_dialer(
768                context,
769                dialer_config,
770                dialer_sink,
771                dialer_stream,
772                listener_public_key,
773            )
774            .await;
775
776            // Verify the error
777            assert!(matches!(result, Err(Error::SharedSecretNotContributory)));
778        });
779    }
780
781    #[test]
782    fn test_upgrade_listener_non_contributory_secret() {
783        let executor = deterministic::Runner::default();
784        executor.start(|context| async move {
785            // Create cryptographic identities
786            let mut dialer_crypto = PrivateKey::from_seed(0);
787            let listener_crypto = PrivateKey::from_seed(1);
788
789            // Set up mock channels
790            let (mut dialer_sink, listener_stream) = mocks::Channel::init();
791            let (listener_sink, _dialer_stream) = mocks::Channel::init();
792
793            // Listener configuration
794            let listener_config = Config {
795                crypto: listener_crypto.clone(),
796                namespace: b"test_namespace".to_vec(),
797                max_message_size: 1024,
798                synchrony_bound: Duration::from_secs(5),
799                max_handshake_age: Duration::from_secs(5),
800                handshake_timeout: Duration::from_secs(5),
801            };
802
803            // Encode all-zero ephemeral public key (32 bytes)
804            let info = handshake::Info::new(
805                listener_config.crypto.public_key(),
806                x25519::PublicKey::from_bytes([0u8; 32]),
807                context.current().epoch_millis(),
808            );
809
810            // Create the signed hello
811            let hello =
812                handshake::Hello::sign(&mut dialer_crypto, &listener_config.namespace, info);
813
814            // Send the hello
815            send_frame(&mut dialer_sink, &hello.encode(), 1024)
816                .await
817                .unwrap();
818
819            // Verify the incoming connection
820            let incoming = IncomingConnection::verify(
821                &context,
822                listener_config,
823                listener_sink,
824                listener_stream,
825            )
826            .await
827            .unwrap();
828
829            // Attempt to upgrade - should fail due to non-contributory shared secret
830            let result = Connection::upgrade_listener(context, incoming).await;
831
832            // Verify the error
833            assert!(matches!(result, Err(Error::SharedSecretNotContributory)));
834        });
835    }
836
837    #[test]
838    fn test_listener_rejects_hello_signed_with_own_key() {
839        let executor = deterministic::Runner::default();
840        executor.start(|context| async move {
841            let self_crypto = PrivateKey::from_seed(0);
842            let self_public_key = self_crypto.public_key();
843
844            let config = Config {
845                crypto: self_crypto.clone(),
846                namespace: b"test_self_connect_namespace".to_vec(),
847                max_message_size: 1024,
848                synchrony_bound: Duration::from_secs(5),
849                max_handshake_age: Duration::from_secs(5),
850                handshake_timeout: Duration::from_secs(1),
851            };
852
853            // Initial hello travels: dialer_sink -> listener_stream
854            let (mut dialer_sink, listener_stream) = mocks::Channel::init();
855            // Reply hello would travel: listener_reply_sink -> dialer_stream
856            let (listener_reply_sink, _dialer_stream) = mocks::Channel::init();
857
858            let listener_config = config.clone();
859            let listener_handle =
860                context
861                    .with_label("self_listener")
862                    .spawn(move |task_ctx| async move {
863                        IncomingConnection::verify(
864                            &task_ctx,
865                            listener_config,
866                            listener_reply_sink,
867                            listener_stream,
868                        )
869                        .await
870                    });
871
872            let max_msg_size = config.max_message_size;
873            let namespace = config.namespace.clone();
874            let handshake_sender_handle =
875                context
876                    .with_label("handshake_sender")
877                    .spawn(move |task_ctx| {
878                        let mut crypto_for_signing = self_crypto.clone();
879                        let recipient_pk = self_public_key.clone();
880                        let ephemeral_pk = super::x25519::PublicKey::from_bytes([0xCDu8; 32]);
881
882                        async move {
883                            let timestamp = task_ctx.current().epoch_millis();
884                            let info =
885                                super::handshake::Info::new(recipient_pk, ephemeral_pk, timestamp);
886                            let hello = super::handshake::Hello::sign(
887                                &mut crypto_for_signing,
888                                &namespace,
889                                info,
890                            );
891                            crate::utils::codec::send_frame(
892                                &mut dialer_sink,
893                                &hello.encode(),
894                                max_msg_size,
895                            )
896                            .await
897                        }
898                    });
899
900            // Ensure hello is sent
901            handshake_sender_handle.await.unwrap().unwrap();
902
903            let listener_result = listener_handle.await.unwrap();
904            assert!(matches!(listener_result, Err(Error::HelloUsesOurKey)));
905        });
906    }
907
908    #[test]
909    fn test_three_message_handshake_protocol() {
910        let executor = deterministic::Runner::default();
911        executor.start(|context| async move {
912            // Create cryptographic identities
913            let dialer_crypto = PrivateKey::from_seed(0);
914            let listener_crypto = PrivateKey::from_seed(1);
915
916            // Set up mock channels for transport simulation
917            let (dialer_sink, listener_stream) = mocks::Channel::init();
918            let (listener_sink, dialer_stream) = mocks::Channel::init();
919
920            // Configuration for dialer
921            let dialer_config = Config {
922                crypto: dialer_crypto.clone(),
923                namespace: b"test_3msg_namespace".to_vec(),
924                max_message_size: 1024,
925                synchrony_bound: Duration::from_secs(5),
926                max_handshake_age: Duration::from_secs(5),
927                handshake_timeout: Duration::from_secs(5),
928            };
929
930            // Configuration for listener
931            let listener_config = Config {
932                crypto: listener_crypto.clone(),
933                namespace: b"test_3msg_namespace".to_vec(),
934                max_message_size: 1024,
935                synchrony_bound: Duration::from_secs(5),
936                max_handshake_age: Duration::from_secs(5),
937                handshake_timeout: Duration::from_secs(5),
938            };
939
940            // Spawn listener to handle incoming connection
941            let listener_handle = context.with_label("listener").spawn({
942                move |context| async move {
943                    let incoming = IncomingConnection::verify(
944                        &context,
945                        listener_config,
946                        listener_sink,
947                        listener_stream,
948                    )
949                    .await
950                    .unwrap();
951                    Connection::upgrade_listener(context, incoming)
952                        .await
953                        .unwrap()
954                }
955            });
956
957            // Dialer initiates the connection
958            let dialer_connection = Connection::upgrade_dialer(
959                context.clone(),
960                dialer_config,
961                dialer_sink,
962                dialer_stream,
963                listener_crypto.public_key(),
964            )
965            .await
966            .unwrap();
967
968            // Wait for listener connection to be established
969            let listener_connection = listener_handle.await.unwrap();
970
971            // Split connections into sender and receiver halves
972            let (mut dialer_sender, mut dialer_receiver) = dialer_connection.split();
973            let (mut listener_sender, mut listener_receiver) = listener_connection.split();
974
975            // Test message exchange after successful 3-message handshake
976            let message1 = b"Hello from dialer after 3-msg handshake";
977            dialer_sender.send(message1).await.unwrap();
978            let received = listener_receiver.receive().await.unwrap();
979            assert_eq!(&received[..], &message1[..]);
980
981            let message2 = b"Hello from listener after 3-msg handshake";
982            listener_sender.send(message2).await.unwrap();
983            let received = dialer_receiver.receive().await.unwrap();
984            assert_eq!(&received[..], &message2[..]);
985        });
986    }
987
988    #[test]
989    fn test_upgrade_dialer_rejects_connecting_to_self() {
990        let executor = deterministic::Runner::default();
991        executor.start(|context| async move {
992            // Create cryptographic identity.
993            let self_crypto = PrivateKey::from_seed(0);
994            let self_public_key = self_crypto.public_key();
995
996            // Configure dialer parameters.
997            let dialer_config = Config {
998                crypto: self_crypto.clone(),
999                namespace: b"test_dial_self_direct".to_vec(),
1000                max_message_size: 1024,
1001                synchrony_bound: Duration::from_secs(5),
1002                max_handshake_age: Duration::from_secs(5),
1003                handshake_timeout: Duration::from_secs(1),
1004            };
1005
1006            // Set up mock channels (not fully utilized due to early error).
1007            let (dialer_sink, _unused_stream) = mocks::Channel::init();
1008            let (_unused_sink, dialer_stream) = mocks::Channel::init();
1009
1010            // Attempt to upgrade dialer connection, targeting self.
1011            let result = Connection::upgrade_dialer(
1012                context.clone(),
1013                dialer_config,
1014                dialer_sink,
1015                dialer_stream,
1016                self_public_key.clone(),
1017            )
1018            .await;
1019
1020            // Verify dialer rejects self-connection attempt.
1021            assert!(matches!(result, Err(Error::DialSelf)));
1022        });
1023    }
1024}