1use super::{handshake, nonce, x25519, Config};
2use crate::{
3 utils::codec::{recv_frame, send_frame},
4 Error,
5};
6use bytes::Bytes;
7use chacha20poly1305::{
8 aead::{Aead, KeyInit},
9 ChaCha20Poly1305,
10};
11use commonware_codec::{DecodeExt, Encode};
12use commonware_cryptography::Scheme;
13use commonware_macros::select;
14use commonware_runtime::{Clock, Sink, Spawner, Stream};
15use commonware_utils::SystemTimeExt as _;
16use rand::{CryptoRng, Rng};
17use std::time::SystemTime;
18
/// Number of bytes by which an encrypted message is longer than its plaintext
/// (the authentication tag appended by the cipher).
///
/// `Sender` and `Receiver` add this to `max_message_size` when framing so that a
/// plaintext of exactly `max_message_size` bytes still fits once encrypted.
const ENCRYPTION_TAG_LENGTH: usize = 16;
22
/// An inbound connection whose first handshake frame has been received and
/// verified, but which has not yet been upgraded to an encrypted [`Connection`].
///
/// Produced by [`IncomingConnection::verify`] and consumed by
/// [`Connection::upgrade_listener`].
pub struct IncomingConnection<C: Scheme, Si: Sink, St: Stream> {
    /// Configuration that was passed to `verify`.
    config: Config<C>,
    /// Outbound half of the transport; untouched by `verify`.
    sink: Si,
    /// Inbound half of the transport; the peer's handshake frame has already
    /// been read from it.
    stream: St,
    /// Handshake deadline computed at the start of `verify`
    /// (`context.current() + config.handshake_timeout`); reused by
    /// `upgrade_listener` to bound sending the reply.
    deadline: SystemTime,
    /// Ephemeral x25519 public key carried in the peer's signed handshake.
    ephemeral_public_key: x25519::PublicKey,
    /// Public key of the signer of the peer's handshake.
    peer_public_key: C::PublicKey,
}
32
33impl<C: Scheme, Si: Sink, St: Stream> IncomingConnection<C, Si, St> {
34 pub async fn verify<E: Clock + Spawner>(
35 context: &E,
36 config: Config<C>,
37 sink: Si,
38 mut stream: St,
39 ) -> Result<Self, Error> {
40 let deadline = context.current() + config.handshake_timeout;
42
43 let msg = select! {
45 _ = context.sleep_until(deadline) => { return Err(Error::HandshakeTimeout) },
46 result = recv_frame(&mut stream, config.max_message_size) => { result? },
47 };
48
49 let signed_handshake =
51 handshake::Signed::<C>::decode(msg).map_err(Error::UnableToDecode)?;
52 signed_handshake.verify(
53 context,
54 &config.crypto,
55 &config.namespace,
56 config.synchrony_bound,
57 config.max_handshake_age,
58 )?;
59 Ok(Self {
60 config,
61 sink,
62 stream,
63 deadline,
64 ephemeral_public_key: signed_handshake.ephemeral(),
65 peer_public_key: signed_handshake.signer(),
66 })
67 }
68
69 pub fn peer(&self) -> C::PublicKey {
71 self.peer_public_key.clone()
72 }
73
74 pub fn ephemeral(&self) -> x25519::PublicKey {
76 self.ephemeral_public_key
77 }
78}
79
/// A connection whose cipher has been established, ready to be
/// [`split`](Connection::split) into an encrypting [`Sender`] and a decrypting
/// [`Receiver`].
pub struct Connection<Si: Sink, St: Stream> {
    /// `true` when created by `upgrade_dialer`, `false` when created by
    /// `upgrade_listener`; `split` passes it to `nonce::Info::new`.
    dialer: bool,
    /// Outbound half of the transport.
    sink: Si,
    /// Inbound half of the transport.
    stream: St,
    /// Cipher used for both directions (cloned for the sender in `split`).
    cipher: ChaCha20Poly1305,
    /// Maximum message size handed to the `Sender` and `Receiver`.
    max_message_size: usize,
}
88
89impl<Si: Sink, St: Stream> Connection<Si, St> {
90 pub fn from_preestablished(
94 dialer: bool,
95 sink: Si,
96 stream: St,
97 cipher: ChaCha20Poly1305,
98 max_message_size: usize,
99 ) -> Self {
100 Self {
101 dialer,
102 sink,
103 stream,
104 cipher,
105 max_message_size,
106 }
107 }
108
109 pub async fn upgrade_dialer<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
114 mut context: R,
115 mut config: Config<C>,
116 mut sink: Si,
117 mut stream: St,
118 peer: C::PublicKey,
119 ) -> Result<Self, Error> {
120 let deadline = context.current() + config.handshake_timeout;
122
123 let secret = x25519::new(&mut context);
125
126 let timestamp = context.current().epoch_millis();
128 let msg = handshake::Signed::sign(
129 &mut config.crypto,
130 &config.namespace,
131 handshake::Info::<C>::new(peer.clone(), &secret, timestamp),
132 )
133 .encode();
134
135 select! {
137 _ = context.sleep_until(deadline) => {
138 return Err(Error::HandshakeTimeout)
139 },
140 result = send_frame(&mut sink, &msg, config.max_message_size) => {
141 result?;
142 },
143 }
144
145 let msg = select! {
147 _ = context.sleep_until(deadline) => {
148 return Err(Error::HandshakeTimeout)
149 },
150 result = recv_frame(&mut stream, config.max_message_size) => {
151 result?
152 },
153 };
154
155 let signed_handshake =
157 handshake::Signed::<C>::decode(msg).map_err(Error::UnableToDecode)?;
158 signed_handshake.verify(
159 &context,
160 &config.crypto,
161 &config.namespace,
162 config.synchrony_bound,
163 config.max_handshake_age,
164 )?;
165
166 if peer != signed_handshake.signer() {
168 return Err(Error::WrongPeer);
169 }
170
171 let shared_secret = secret.diffie_hellman(signed_handshake.ephemeral().as_ref());
173 let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
174 .map_err(|_| Error::CipherCreationFailed)?;
175
176 Ok(Self {
178 dialer: true,
179 sink,
180 stream,
181 cipher,
182 max_message_size: config.max_message_size,
183 })
184 }
185
186 pub async fn upgrade_listener<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
192 mut context: R,
193 incoming: IncomingConnection<C, Si, St>,
194 ) -> Result<Self, Error> {
195 let max_message_size = incoming.config.max_message_size;
197 let mut crypto = incoming.config.crypto;
198 let namespace = incoming.config.namespace;
199 let mut sink = incoming.sink;
200 let stream = incoming.stream;
201
202 let secret = x25519::new(&mut context);
204
205 let timestamp = context.current().epoch_millis();
207 let msg = handshake::Signed::sign(
208 &mut crypto,
209 &namespace,
210 handshake::Info::<C>::new(incoming.peer_public_key, &secret, timestamp),
211 )
212 .encode();
213
214 select! {
216 _ = context.sleep_until(incoming.deadline) => {
217 return Err(Error::HandshakeTimeout)
218 },
219 result = send_frame(&mut sink, &msg, max_message_size) => {
220 result?;
221 },
222 }
223
224 let shared_secret = secret.diffie_hellman(incoming.ephemeral_public_key.as_ref());
226 let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
227 .map_err(|_| Error::CipherCreationFailed)?;
228
229 Ok(Connection {
231 dialer: false,
232 sink,
233 stream,
234 cipher,
235 max_message_size,
236 })
237 }
238
239 pub fn split(self) -> (Sender<Si>, Receiver<St>) {
244 (
245 Sender {
246 cipher: self.cipher.clone(),
247 sink: self.sink,
248 max_message_size: self.max_message_size,
249 nonce: nonce::Info::new(self.dialer),
250 },
251 Receiver {
252 cipher: self.cipher,
253 stream: self.stream,
254 max_message_size: self.max_message_size,
255 nonce: nonce::Info::new(!self.dialer),
256 },
257 )
258 }
259}
260
/// The sending half of a [`Connection`]: encrypts each message and writes it
/// to the sink as one frame.
pub struct Sender<Si: Sink> {
    /// Cipher used to encrypt outgoing messages.
    cipher: ChaCha20Poly1305,
    /// Transport half that encrypted frames are written to.
    sink: Si,

    /// Maximum plaintext size; `ENCRYPTION_TAG_LENGTH` is added on top when
    /// framing the ciphertext.
    max_message_size: usize,
    /// Nonce state: encoded for each encryption and advanced after it succeeds.
    nonce: nonce::Info,
}
269
270impl<Si: Sink> crate::Sender for Sender<Si> {
271 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
272 let msg = self
274 .cipher
275 .encrypt(&self.nonce.encode(), msg.as_ref())
276 .map_err(|_| Error::EncryptionFailed)?;
277 self.nonce.inc()?;
278
279 send_frame(
281 &mut self.sink,
282 &msg,
283 self.max_message_size + ENCRYPTION_TAG_LENGTH,
284 )
285 .await?;
286 Ok(())
287 }
288}
289
/// The receiving half of a [`Connection`]: reads one frame at a time from the
/// stream and decrypts it.
pub struct Receiver<St: Stream> {
    /// Cipher used to decrypt incoming messages.
    cipher: ChaCha20Poly1305,
    /// Transport half that encrypted frames are read from.
    stream: St,

    /// Maximum plaintext size; `ENCRYPTION_TAG_LENGTH` is added on top when
    /// reading a ciphertext frame.
    max_message_size: usize,
    /// Nonce state: encoded for each decryption and advanced after it succeeds.
    nonce: nonce::Info,
}
298
299impl<St: Stream> crate::Receiver for Receiver<St> {
300 async fn receive(&mut self) -> Result<Bytes, Error> {
301 let msg = recv_frame(
303 &mut self.stream,
304 self.max_message_size + ENCRYPTION_TAG_LENGTH,
305 )
306 .await?;
307
308 let msg = self
310 .cipher
311 .decrypt(&self.nonce.encode(), msg.as_ref())
312 .map_err(|_| Error::DecryptionFailed)?;
313 self.nonce.inc()?;
314
315 Ok(Bytes::from(msg))
316 }
317}
318
// Unit tests for the encrypted sender/receiver halves and for the full
// dialer/listener handshake, all run on the deterministic runtime with
// in-memory mock channels.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Receiver as _, Sender as _};
    use commonware_cryptography::{Ed25519, Signer};
    use commonware_runtime::{deterministic, mocks, Metrics, Runner};
    use std::time::Duration;

    // A frame that was never produced by the cipher must be rejected with
    // `DecryptionFailed`.
    #[test]
    fn test_decryption_failure() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let (mut sink, stream) = mocks::Channel::init();
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: 1024,
                nonce: nonce::Info::new(false),
            };

            // Write raw, unencrypted bytes straight onto the wire.
            send_frame(&mut sink, b"invalid data", receiver.max_message_size)
                .await
                .unwrap();

            let result = receiver.receive().await;
            assert!(matches!(result, Err(Error::DecryptionFailed)));
        });
    }

    // Sending a message one byte over the sender's limit fails with
    // `SendTooLarge`, reporting the encrypted length (plaintext + tag).
    #[test]
    fn test_send_too_large() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let message = b"hello world";
            let (sink, _) = mocks::Channel::init();
            let mut sender = Sender {
                cipher,
                sink,
                // One byte too small for `message`.
                max_message_size: message.len() - 1,
                nonce: nonce::Info::new(true),
            };

            let result = sender.send(message).await;
            let expected_length = message.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(result, Err(Error::SendTooLarge(n)) if n == expected_length));
        });
    }

    // A frame that fits the sender's limit but is one byte over the receiver's
    // limit fails on receipt with `RecvTooLarge`, reporting the encrypted length.
    #[test]
    fn test_receive_too_large() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let message = b"hello world";
            let (sink, stream) = mocks::Channel::init();

            // The sender accepts exactly `message.len()` bytes...
            let mut sender = Sender {
                cipher: cipher.clone(),
                sink,
                max_message_size: message.len(),
                nonce: nonce::Info::new(true),
            };
            // ...but the receiver allows one byte less.
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: message.len() - 1,
                nonce: nonce::Info::new(false),
            };

            sender.send(message).await.unwrap();
            let result = receiver.receive().await;
            let expected_length = message.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(result, Err(Error::RecvTooLarge(n)) if n == expected_length));
        });
    }

    // Two pre-established connections sharing one cipher can exchange messages
    // in both directions, including several messages in a row.
    #[test]
    fn test_send_receive() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let max_message_size = 1024;

            // One mock channel per direction.
            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            // Dialer endpoint (`dialer = true`).
            let connection_dialer = Connection::from_preestablished(
                true,
                dialer_sink,
                dialer_stream,
                cipher.clone(),
                max_message_size,
            );

            // Listener endpoint (`dialer = false`).
            let connection_listener = Connection::from_preestablished(
                false,
                listener_sink,
                listener_stream,
                cipher,
                max_message_size,
            );

            let (mut dialer_sender, mut dialer_receiver) = connection_dialer.split();
            let (mut listener_sender, mut listener_receiver) = connection_listener.split();

            // Dialer -> listener.
            let msg1 = b"hello from dialer";
            dialer_sender.send(msg1).await.unwrap();
            let received1 = listener_receiver.receive().await.unwrap();
            assert_eq!(received1, &msg1[..]);

            // Listener -> dialer.
            let msg2 = b"hello from listener";
            listener_sender.send(msg2).await.unwrap();
            let received2 = dialer_receiver.receive().await.unwrap();
            assert_eq!(received2, &msg2[..]);

            // Several consecutive messages in each direction.
            let messages_to_listener = vec![b"msg1", b"msg2", b"msg3"];
            for msg in &messages_to_listener {
                dialer_sender.send(*msg).await.unwrap();
                let received = listener_receiver.receive().await.unwrap();
                assert_eq!(received, &msg[..]);
            }
            let messages_to_dialer = vec![b"reply1", b"reply2", b"reply3"];
            for msg in &messages_to_dialer {
                listener_sender.send(*msg).await.unwrap();
                let received = dialer_receiver.receive().await.unwrap();
                assert_eq!(received, &msg[..]);
            }
        });
    }

    // Runs the real handshake on both sides (listener in a spawned task, dialer
    // inline) and then exchanges messages over the resulting connections.
    #[test]
    fn test_full_connection_establishment_and_exchange() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Distinct identities for the two endpoints.
            let dialer_crypto = Ed25519::from_seed(0);
            let listener_crypto = Ed25519::from_seed(1);

            // One mock channel per direction.
            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            // Both sides use the same namespace and limits.
            let dialer_config = Config {
                crypto: dialer_crypto.clone(),
                namespace: b"test_namespace".to_vec(),
                max_message_size: 1024,
                synchrony_bound: Duration::from_secs(5),
                max_handshake_age: Duration::from_secs(5),
                handshake_timeout: Duration::from_secs(5),
            };

            let listener_config = Config {
                crypto: listener_crypto.clone(),
                namespace: b"test_namespace".to_vec(),
                max_message_size: 1024,
                synchrony_bound: Duration::from_secs(5),
                max_handshake_age: Duration::from_secs(5),
                handshake_timeout: Duration::from_secs(5),
            };

            // Listener: verify the incoming handshake, then upgrade.
            let listener_handle = context.with_label("listener").spawn({
                move |context| async move {
                    let incoming = IncomingConnection::verify(
                        &context,
                        listener_config,
                        listener_sink,
                        listener_stream,
                    )
                    .await
                    .unwrap();
                    Connection::upgrade_listener(context, incoming)
                        .await
                        .unwrap()
                }
            });

            // Dialer: upgrade, expecting the listener's public key.
            let dialer_connection = Connection::upgrade_dialer(
                context.clone(),
                dialer_config,
                dialer_sink,
                dialer_stream,
                listener_crypto.public_key(),
            )
            .await
            .unwrap();

            // Wait for the listener task to finish its side.
            let listener_connection = listener_handle.await.unwrap();

            let (mut dialer_sender, mut dialer_receiver) = dialer_connection.split();
            let (mut listener_sender, mut listener_receiver) = listener_connection.split();

            // Dialer -> listener: two sends queued before either is received.
            let message1 = b"Hello from dialer";
            dialer_sender.send(message1).await.unwrap();
            dialer_sender.send(message1).await.unwrap();
            let received = listener_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message1[..]);
            let received = listener_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message1[..]);

            // Listener -> dialer: same pattern in the other direction.
            let message2 = b"Hello from listener";
            listener_sender.send(message2).await.unwrap();
            listener_sender.send(message2).await.unwrap();
            let received = dialer_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message2[..]);
            let received = dialer_receiver.receive().await.unwrap();
            assert_eq!(&received[..], &message2[..]);
        });
    }

    // The dialer must fail with `WrongPeer` when the handshake reply is signed
    // by a key other than the one it expected.
    #[test]
    fn test_upgrade_dialer_wrong_peer() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = Ed25519::from_seed(0);
            // The dialer expects seed 1 but the mock peer signs with seed 2.
            let expected_peer = Ed25519::from_seed(1).public_key();
            let mut actual_peer = Ed25519::from_seed(2);

            // One mock channel per direction.
            let (dialer_sink, mut peer_stream) = mocks::Channel::init();
            let (mut peer_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config = Config {
                crypto: dialer_crypto,
                namespace: b"test_namespace".to_vec(),
                max_message_size: 1024,
                synchrony_bound: Duration::from_secs(5),
                max_handshake_age: Duration::from_secs(5),
                handshake_timeout: Duration::from_secs(5),
            };
            // Copy of the dialer's config for the mock peer, which uses it for
            // the namespace and for the dialer's public key.
            let peer_config = dialer_config.clone();

            // Hand-rolled peer: read the dialer's handshake, then reply with a
            // handshake addressed to the dialer but signed by `actual_peer`.
            context.with_label("mock_peer").spawn({
                move |mut context| async move {
                    // Consume the dialer's handshake; its contents are not checked.
                    let msg = recv_frame(&mut peer_stream, 1024).await.unwrap();
                    let _ = handshake::Signed::<Ed25519>::decode(msg).unwrap();

                    // Build the reply, signed with the unexpected key.
                    let secret = x25519::new(&mut context);
                    let timestamp = context.current().epoch_millis();
                    let info =
                        handshake::Info::new(peer_config.crypto.public_key(), &secret, timestamp);
                    let signed_handshake =
                        handshake::Signed::sign(&mut actual_peer, &peer_config.namespace, info);
                    send_frame(&mut peer_sink, &signed_handshake.encode(), 1024)
                        .await
                        .unwrap();
                }
            });

            let result = Connection::upgrade_dialer(
                context,
                dialer_config,
                dialer_sink,
                dialer_stream,
                expected_peer,
            )
            .await;

            assert!(matches!(result, Err(Error::WrongPeer)));
        });
    }
}