commonware_stream/public_key/
connection.rs

1use super::{handshake, nonce, x25519, Config};
2use crate::{
3    utils::codec::{recv_frame, send_frame},
4    Error,
5};
6use bytes::Bytes;
7use chacha20poly1305::{
8    aead::{Aead, KeyInit},
9    ChaCha20Poly1305,
10};
11use commonware_codec::{DecodeExt, Encode};
12use commonware_cryptography::Scheme;
13use commonware_macros::select;
14use commonware_runtime::{Clock, Sink, Spawner, Stream};
15use commonware_utils::SystemTimeExt as _;
16use rand::{CryptoRng, Rng};
17use std::time::SystemTime;
18
/// Number of bytes the ChaCha20-Poly1305 authentication (Poly1305) tag adds to
/// every ciphertext.
///
/// Wire-level frame limits are the plaintext limit plus this overhead.
const ENCRYPTION_TAG_LENGTH: usize = 16;
22
23/// An incoming connection with a verified peer handshake.
24pub struct IncomingConnection<C: Scheme, Si: Sink, St: Stream> {
25    config: Config<C>,
26    sink: Si,
27    stream: St,
28    deadline: SystemTime,
29    ephemeral_public_key: x25519::PublicKey,
30    peer_public_key: C::PublicKey,
31}
32
33impl<C: Scheme, Si: Sink, St: Stream> IncomingConnection<C, Si, St> {
34    pub async fn verify<E: Clock + Spawner>(
35        context: &E,
36        config: Config<C>,
37        sink: Si,
38        mut stream: St,
39    ) -> Result<Self, Error> {
40        // Set handshake deadline
41        let deadline = context.current() + config.handshake_timeout;
42
43        // Wait for up to handshake timeout for response
44        let msg = select! {
45            _ = context.sleep_until(deadline) => { return Err(Error::HandshakeTimeout) },
46            result = recv_frame(&mut stream, config.max_message_size) => { result? },
47        };
48
49        // Verify handshake message from peer
50        let signed_handshake =
51            handshake::Signed::<C>::decode(msg).map_err(Error::UnableToDecode)?;
52        signed_handshake.verify(
53            context,
54            &config.crypto,
55            &config.namespace,
56            config.synchrony_bound,
57            config.max_handshake_age,
58        )?;
59        Ok(Self {
60            config,
61            sink,
62            stream,
63            deadline,
64            ephemeral_public_key: signed_handshake.ephemeral(),
65            peer_public_key: signed_handshake.signer(),
66        })
67    }
68
69    /// The public key of the peer attempting to connect.
70    pub fn peer(&self) -> C::PublicKey {
71        self.peer_public_key.clone()
72    }
73
74    /// The ephemeral public key of the peer attempting to connect.
75    pub fn ephemeral(&self) -> x25519::PublicKey {
76        self.ephemeral_public_key
77    }
78}
79
80/// A fully initialized connection with some peer.
81pub struct Connection<Si: Sink, St: Stream> {
82    dialer: bool,
83    sink: Si,
84    stream: St,
85    cipher: ChaCha20Poly1305,
86    max_message_size: usize,
87}
88
89impl<Si: Sink, St: Stream> Connection<Si, St> {
90    /// Create a new connection from pre-established components.
91    ///
92    /// This is useful in tests, or when upgrading a connection that has already been verified.
93    pub fn from_preestablished(
94        dialer: bool,
95        sink: Si,
96        stream: St,
97        cipher: ChaCha20Poly1305,
98        max_message_size: usize,
99    ) -> Self {
100        Self {
101            dialer,
102            sink,
103            stream,
104            cipher,
105            max_message_size,
106        }
107    }
108
109    /// Attempt to upgrade a raw connection we initiated.
110    ///
111    /// This will send a handshake message to the peer, wait for a response,
112    /// and verify the peer's handshake message.
113    pub async fn upgrade_dialer<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
114        mut context: R,
115        mut config: Config<C>,
116        mut sink: Si,
117        mut stream: St,
118        peer: C::PublicKey,
119    ) -> Result<Self, Error> {
120        // Set handshake deadline
121        let deadline = context.current() + config.handshake_timeout;
122
123        // Generate shared secret
124        let secret = x25519::new(&mut context);
125
126        // Send handshake
127        let timestamp = context.current().epoch_millis();
128        let msg = handshake::Signed::sign(
129            &mut config.crypto,
130            &config.namespace,
131            handshake::Info::<C>::new(peer.clone(), &secret, timestamp),
132        )
133        .encode();
134
135        // Wait for up to handshake timeout to send
136        select! {
137            _ = context.sleep_until(deadline) => {
138                return Err(Error::HandshakeTimeout)
139            },
140            result = send_frame(&mut sink, &msg, config.max_message_size) => {
141                result?;
142            },
143        }
144
145        // Wait for up to handshake timeout for response
146        let msg = select! {
147            _ = context.sleep_until(deadline) => {
148                return Err(Error::HandshakeTimeout)
149            },
150            result = recv_frame(&mut stream, config.max_message_size) => {
151                result?
152            },
153        };
154
155        // Verify handshake message from peer
156        let signed_handshake =
157            handshake::Signed::<C>::decode(msg).map_err(Error::UnableToDecode)?;
158        signed_handshake.verify(
159            &context,
160            &config.crypto,
161            &config.namespace,
162            config.synchrony_bound,
163            config.max_handshake_age,
164        )?;
165
166        // Ensure we connected to the right peer
167        if peer != signed_handshake.signer() {
168            return Err(Error::WrongPeer);
169        }
170
171        // Create cipher
172        let shared_secret = secret.diffie_hellman(signed_handshake.ephemeral().as_ref());
173        let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
174            .map_err(|_| Error::CipherCreationFailed)?;
175
176        // We keep track of dialer to determine who adds a bit to their nonce (to prevent reuse)
177        Ok(Self {
178            dialer: true,
179            sink,
180            stream,
181            cipher,
182            max_message_size: config.max_message_size,
183        })
184    }
185
186    /// Attempt to upgrade a connection initiated by some peer.
187    ///
188    /// Because we already verified the peer's handshake, this function
189    /// only needs to send our handshake message for the connection to be fully
190    /// initialized.
191    pub async fn upgrade_listener<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
192        mut context: R,
193        incoming: IncomingConnection<C, Si, St>,
194    ) -> Result<Self, Error> {
195        // Extract fields
196        let max_message_size = incoming.config.max_message_size;
197        let mut crypto = incoming.config.crypto;
198        let namespace = incoming.config.namespace;
199        let mut sink = incoming.sink;
200        let stream = incoming.stream;
201
202        // Generate personal secret
203        let secret = x25519::new(&mut context);
204
205        // Send handshake
206        let timestamp = context.current().epoch_millis();
207        let msg = handshake::Signed::sign(
208            &mut crypto,
209            &namespace,
210            handshake::Info::<C>::new(incoming.peer_public_key, &secret, timestamp),
211        )
212        .encode();
213
214        // Wait for up to handshake timeout
215        select! {
216            _ = context.sleep_until(incoming.deadline) => {
217                return Err(Error::HandshakeTimeout)
218            },
219            result = send_frame(&mut sink, &msg, max_message_size) => {
220                result?;
221            },
222        }
223
224        // Create cipher based on the shared secret
225        let shared_secret = secret.diffie_hellman(incoming.ephemeral_public_key.as_ref());
226        let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
227            .map_err(|_| Error::CipherCreationFailed)?;
228
229        // Track whether or not we are the dialer to ensure we send correctly formatted nonces.
230        Ok(Connection {
231            dialer: false,
232            sink,
233            stream,
234            cipher,
235            max_message_size,
236        })
237    }
238
239    /// Split the connection into a `Sender` and `Receiver`.
240    ///
241    /// This pattern is commonly used to efficiently send and receive messages
242    /// over the same connection concurrently.
243    pub fn split(self) -> (Sender<Si>, Receiver<St>) {
244        (
245            Sender {
246                cipher: self.cipher.clone(),
247                sink: self.sink,
248                max_message_size: self.max_message_size,
249                nonce: nonce::Info::new(self.dialer),
250            },
251            Receiver {
252                cipher: self.cipher,
253                stream: self.stream,
254                max_message_size: self.max_message_size,
255                nonce: nonce::Info::new(!self.dialer),
256            },
257        )
258    }
259}
260
261/// The half of the `Connection` that implements `crate::Sender`.
262pub struct Sender<Si: Sink> {
263    cipher: ChaCha20Poly1305,
264    sink: Si,
265
266    max_message_size: usize,
267    nonce: nonce::Info,
268}
269
270impl<Si: Sink> crate::Sender for Sender<Si> {
271    async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
272        // Encrypt data
273        let msg = self
274            .cipher
275            .encrypt(&self.nonce.encode(), msg.as_ref())
276            .map_err(|_| Error::EncryptionFailed)?;
277        self.nonce.inc()?;
278
279        // Send data
280        send_frame(
281            &mut self.sink,
282            &msg,
283            self.max_message_size + ENCRYPTION_TAG_LENGTH,
284        )
285        .await?;
286        Ok(())
287    }
288}
289
290/// The half of a `Connection` that implements `crate::Receiver`.
291pub struct Receiver<St: Stream> {
292    cipher: ChaCha20Poly1305,
293    stream: St,
294
295    max_message_size: usize,
296    nonce: nonce::Info,
297}
298
299impl<St: Stream> crate::Receiver for Receiver<St> {
300    async fn receive(&mut self) -> Result<Bytes, Error> {
301        // Read data
302        let msg = recv_frame(
303            &mut self.stream,
304            self.max_message_size + ENCRYPTION_TAG_LENGTH,
305        )
306        .await?;
307
308        // Decrypt data
309        let msg = self
310            .cipher
311            .decrypt(&self.nonce.encode(), msg.as_ref())
312            .map_err(|_| Error::DecryptionFailed)?;
313        self.nonce.inc()?;
314
315        Ok(Bytes::from(msg))
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Receiver as _, Sender as _};
    use commonware_cryptography::{Ed25519, Signer};
    use commonware_runtime::{deterministic, mocks, Metrics, Runner};
    use std::time::Duration;

    #[test]
    fn test_decryption_failure() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let (mut sink, stream) = mocks::Channel::init();
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: 1024,
                nonce: nonce::Info::new(false),
            };

            // Send bytes that are not a valid ciphertext for this key and nonce.
            send_frame(&mut sink, b"invalid data", receiver.max_message_size)
                .await
                .unwrap();

            let result = receiver.receive().await;
            assert!(matches!(result, Err(Error::DecryptionFailed)));
        });
    }

    #[test]
    fn test_send_too_large() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let message = b"hello world";
            let (sink, _) = mocks::Channel::init();
            let mut sender = Sender {
                cipher,
                sink,
                max_message_size: message.len() - 1,
                nonce: nonce::Info::new(true),
            };

            // The reported size is the would-be ciphertext length (payload + tag).
            let result = sender.send(message).await;
            let expected_length = message.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(result, Err(Error::SendTooLarge(n)) if n == expected_length));
        });
    }

    #[test]
    fn test_receive_too_large() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let message = b"hello world";
            let (sink, stream) = mocks::Channel::init();

            let mut sender = Sender {
                cipher: cipher.clone(),
                sink,
                max_message_size: message.len(),
                nonce: nonce::Info::new(true),
            };
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: message.len() - 1,
                nonce: nonce::Info::new(false),
            };

            sender.send(message).await.unwrap();
            let result = receiver.receive().await;
            let expected_length = message.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(result, Err(Error::RecvTooLarge(n)) if n == expected_length));
        });
    }

    #[test]
    fn test_send_receive() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let max_message_size = 1024;

            // Create channels
            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            // Create dialer connection
            let connection_dialer = Connection::from_preestablished(
                true, // dialer
                dialer_sink,
                dialer_stream,
                cipher.clone(),
                max_message_size,
            );

            // Create listener connection
            let connection_listener = Connection::from_preestablished(
                false, // listener
                listener_sink,
                listener_stream,
                cipher,
                max_message_size,
            );

            // Split into sender and receiver for both connections
            let (mut dialer_sender, mut dialer_receiver) = connection_dialer.split();
            let (mut listener_sender, mut listener_receiver) = connection_listener.split();

            // Test 1: Send from dialer to listener
            let msg1 = b"hello from dialer";
            dialer_sender.send(msg1).await.unwrap();
            let received1 = listener_receiver.receive().await.unwrap();
            assert_eq!(received1, &msg1[..]);

            // Test 2: Send from listener to dialer
            let msg2 = b"hello from listener";
            listener_sender.send(msg2).await.unwrap();
            let received2 = dialer_receiver.receive().await.unwrap();
            assert_eq!(received2, &msg2[..]);

            // Test 3: Send multiple messages both ways
            let messages_to_listener = [b"msg1", b"msg2", b"msg3"];
            for msg in &messages_to_listener {
                dialer_sender.send(*msg).await.unwrap();
                let received = listener_receiver.receive().await.unwrap();
                assert_eq!(received, &msg[..]);
            }
            let messages_to_dialer = [b"reply1", b"reply2", b"reply3"];
            for msg in &messages_to_dialer {
                listener_sender.send(*msg).await.unwrap();
                let received = dialer_receiver.receive().await.unwrap();
                assert_eq!(received, &msg[..]);
            }
        });
    }

    #[test]
    fn test_full_connection_establishment_and_exchange() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Create cryptographic identities
            let dialer_crypto = Ed25519::from_seed(0);
            let listener_crypto = Ed25519::from_seed(1);

            // Set up mock channels for transport simulation
            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            // Configuration for dialer
            let dialer_config = Config {
                crypto: dialer_crypto,
                namespace: b"test_namespace".to_vec(),
                max_message_size: 1024,
                synchrony_bound: Duration::from_secs(5),
                max_handshake_age: Duration::from_secs(5),
                handshake_timeout: Duration::from_secs(5),
            };

            // Configuration for listener
            let listener_config = Config {
                crypto: listener_crypto.clone(),
                namespace: b"test_namespace".to_vec(),
                max_message_size: 1024,
                synchrony_bound: Duration::from_secs(5),
                max_handshake_age: Duration::from_secs(5),
                handshake_timeout: Duration::from_secs(5),
            };

            // Spawn listener to handle incoming connection
            let listener_handle = context.with_label("listener").spawn({
                move |context| async move {
                    let incoming = IncomingConnection::verify(
                        &context,
                        listener_config,
                        listener_sink,
                        listener_stream,
                    )
                    .await
                    .unwrap();
                    Connection::upgrade_listener(context, incoming)
                        .await
                        .unwrap()
                }
            });

            // Dialer initiates the connection
            let dialer_connection = Connection::upgrade_dialer(
                context.clone(),
                dialer_config,
                dialer_sink,
                dialer_stream,
                listener_crypto.public_key(),
            )
            .await
            .unwrap();

            // Wait for listener connection to be established
            let listener_connection = listener_handle.await.unwrap();

            // Split connections into sender and receiver halves
            let (mut dialer_sender, mut dialer_receiver) = dialer_connection.split();
            let (mut listener_sender, mut listener_receiver) = listener_connection.split();

            // Dialer sends to listener twice (exercises nonce advancement)
            let message1 = b"Hello from dialer";
            dialer_sender.send(message1).await.unwrap();
            dialer_sender.send(message1).await.unwrap();
            let received = listener_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message1[..]);
            let received = listener_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message1[..]);

            // Listener sends to dialer twice
            let message2 = b"Hello from listener";
            listener_sender.send(message2).await.unwrap();
            listener_sender.send(message2).await.unwrap();
            let received = dialer_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message2[..]);
            let received = dialer_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message2[..]);
        });
    }

    #[test]
    fn test_upgrade_dialer_wrong_peer() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Create cryptographic identities
            let dialer_crypto = Ed25519::from_seed(0);
            let expected_peer = Ed25519::from_seed(1).public_key();
            let mut actual_peer = Ed25519::from_seed(2);

            // Set up mock channels
            let (dialer_sink, mut peer_stream) = mocks::Channel::init();
            let (mut peer_sink, dialer_stream) = mocks::Channel::init();

            // Dialer configuration
            let dialer_config = Config {
                crypto: dialer_crypto,
                namespace: b"test_namespace".to_vec(),
                max_message_size: 1024,
                synchrony_bound: Duration::from_secs(5),
                max_handshake_age: Duration::from_secs(5),
                handshake_timeout: Duration::from_secs(5),
            };
            let peer_config = dialer_config.clone();

            // Spawn a mock peer that replies with a valid handshake addressed to the
            // dialer but signed by an identity the dialer did not expect.
            context.with_label("mock_peer").spawn({
                move |mut context| async move {
                    // Read the handshake from dialer
                    let msg = recv_frame(&mut peer_stream, 1024).await.unwrap();
                    let _ = handshake::Signed::<Ed25519>::decode(msg).unwrap(); // Simulate reading

                    // Create and send own handshake
                    let secret = x25519::new(&mut context);
                    let timestamp = context.current().epoch_millis();
                    let info =
                        handshake::Info::new(peer_config.crypto.public_key(), &secret, timestamp);
                    let signed_handshake =
                        handshake::Signed::sign(&mut actual_peer, &peer_config.namespace, info);
                    send_frame(&mut peer_sink, &signed_handshake.encode(), 1024)
                        .await
                        .unwrap();
                }
            });

            // Attempt connection with expected peer key
            let result = Connection::upgrade_dialer(
                context,
                dialer_config,
                dialer_sink,
                dialer_stream,
                expected_peer,
            )
            .await;

            // Verify the error
            assert!(matches!(result, Err(Error::WrongPeer)));
        });
    }
}