1use super::{
2 handshake::{create_handshake, Handshake, IncomingHandshake},
3 nonce, x25519, Config,
4};
5use crate::{
6 utils::codec::{recv_frame, send_frame},
7 Error,
8};
9use bytes::Bytes;
10use chacha20poly1305::{
11 aead::{Aead, KeyInit},
12 ChaCha20Poly1305,
13};
14use commonware_cryptography::Scheme;
15use commonware_macros::select;
16use commonware_runtime::{Clock, Sink, Spawner, Stream};
17use commonware_utils::SystemTimeExt as _;
18use rand::{CryptoRng, Rng};
19
/// Size, in bytes, of the Poly1305 authentication tag that ChaCha20-Poly1305 appends to
/// every ciphertext (a 128-bit tag). Frame limits add this on top of the plaintext limit.
const ENCRYPTION_TAG_LENGTH: usize = 128 / 8;
23
24pub struct IncomingConnection<C: Scheme, Si: Sink, St: Stream> {
26 config: Config<C>,
27 handshake: IncomingHandshake<Si, St, C>,
28}
29
30impl<C: Scheme, Si: Sink, St: Stream> IncomingConnection<C, Si, St> {
31 pub async fn verify<R: Rng + CryptoRng + Spawner + Clock>(
33 context: &R,
34 config: Config<C>,
35 sink: Si,
36 stream: St,
37 ) -> Result<Self, Error> {
38 let handshake = IncomingHandshake::verify(
39 context,
40 &config.crypto,
41 &config.namespace,
42 config.max_message_size,
43 config.synchrony_bound,
44 config.max_handshake_age,
45 config.handshake_timeout,
46 sink,
47 stream,
48 )
49 .await?;
50 Ok(Self { config, handshake })
51 }
52
53 pub fn peer(&self) -> C::PublicKey {
55 self.handshake.peer_public_key.clone()
56 }
57}
58
59pub struct Connection<Si: Sink, St: Stream> {
61 dialer: bool,
62 sink: Si,
63 stream: St,
64 cipher: ChaCha20Poly1305,
65 max_message_size: usize,
66}
67
68impl<Si: Sink, St: Stream> Connection<Si, St> {
69 pub fn from_preestablished(
73 dialer: bool,
74 sink: Si,
75 stream: St,
76 cipher: ChaCha20Poly1305,
77 max_message_size: usize,
78 ) -> Self {
79 Self {
80 dialer,
81 sink,
82 stream,
83 cipher,
84 max_message_size,
85 }
86 }
87
88 pub async fn upgrade_dialer<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
93 mut context: R,
94 mut config: Config<C>,
95 mut sink: Si,
96 mut stream: St,
97 peer: C::PublicKey,
98 ) -> Result<Self, Error> {
99 let deadline = context.current() + config.handshake_timeout;
101
102 let secret = x25519::new(&mut context);
104 let ephemeral = x25519_dalek::PublicKey::from(&secret);
105
106 let timestamp = context.current().epoch_millis();
108 let msg = create_handshake(
109 &mut config.crypto,
110 &config.namespace,
111 timestamp,
112 peer.clone(),
113 ephemeral,
114 )?;
115
116 select! {
118 _ = context.sleep_until(deadline) => {
119 return Err(Error::HandshakeTimeout)
120 },
121 result = send_frame(&mut sink, &msg, config.max_message_size) => {
122 result.map_err(|_| Error::SendFailed)?;
123 },
124 }
125
126 let msg = select! {
128 _ = context.sleep_until(deadline) => {
129 return Err(Error::HandshakeTimeout)
130 },
131 result = recv_frame(&mut stream, config.max_message_size) => {
132 result.map_err(|_| Error::RecvFailed)?
133 },
134 };
135
136 let handshake = Handshake::verify(
138 &context,
139 &config.crypto,
140 &config.namespace,
141 config.synchrony_bound,
142 config.max_handshake_age,
143 msg,
144 )?;
145
146 if peer != handshake.peer_public_key {
148 return Err(Error::WrongPeer);
149 }
150
151 let shared_secret = secret.diffie_hellman(&handshake.ephemeral_public_key);
153 let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
154 .map_err(|_| Error::CipherCreationFailed)?;
155
156 Ok(Self {
158 dialer: true,
159 sink,
160 stream,
161 cipher,
162 max_message_size: config.max_message_size,
163 })
164 }
165
166 pub async fn upgrade_listener<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
172 mut context: R,
173 incoming: IncomingConnection<C, Si, St>,
174 ) -> Result<Self, Error> {
175 let secret = x25519::new(&mut context);
177 let ephemeral = x25519_dalek::PublicKey::from(&secret);
178
179 let (mut handshake, mut config) = (incoming.handshake, incoming.config);
181 let timestamp = context.current().epoch_millis();
182 let msg = create_handshake(
183 &mut config.crypto,
184 &config.namespace,
185 timestamp,
186 handshake.peer_public_key,
187 ephemeral,
188 )?;
189
190 select! {
192 _ = context.sleep_until(handshake.deadline) => {
193 return Err(Error::HandshakeTimeout)
194 },
195 result = send_frame(&mut handshake.sink, &msg, config.max_message_size) => {
196 result.map_err(|_| Error::SendFailed)?;
197 },
198 }
199
200 let shared_secret = secret.diffie_hellman(&handshake.ephemeral_public_key);
202 let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
203 .map_err(|_| Error::CipherCreationFailed)?;
204
205 Ok(Connection {
207 dialer: false,
208 sink: handshake.sink,
209 stream: handshake.stream,
210 cipher,
211 max_message_size: config.max_message_size,
212 })
213 }
214
215 pub fn split(self) -> (Sender<Si>, Receiver<St>) {
220 (
221 Sender {
222 cipher: self.cipher.clone(),
223 sink: self.sink,
224 max_message_size: self.max_message_size,
225 nonce: nonce::Info::new(self.dialer),
226 },
227 Receiver {
228 cipher: self.cipher,
229 stream: self.stream,
230 max_message_size: self.max_message_size,
231 nonce: nonce::Info::new(!self.dialer),
232 },
233 )
234 }
235}
236
237pub struct Sender<Si: Sink> {
239 cipher: ChaCha20Poly1305,
240 sink: Si,
241
242 max_message_size: usize,
243 nonce: nonce::Info,
244}
245
246impl<Si: Sink> crate::Sender for Sender<Si> {
247 async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
248 let msg = self
250 .cipher
251 .encrypt(&self.nonce.encode(), msg.as_ref())
252 .map_err(|_| Error::EncryptionFailed)?;
253 self.nonce.inc()?;
254
255 send_frame(
257 &mut self.sink,
258 &msg,
259 self.max_message_size + ENCRYPTION_TAG_LENGTH,
260 )
261 .await?;
262 Ok(())
263 }
264}
265
266pub struct Receiver<St: Stream> {
268 cipher: ChaCha20Poly1305,
269 stream: St,
270
271 max_message_size: usize,
272 nonce: nonce::Info,
273}
274
275impl<St: Stream> crate::Receiver for Receiver<St> {
276 async fn receive(&mut self) -> Result<Bytes, Error> {
277 let msg = recv_frame(
279 &mut self.stream,
280 self.max_message_size + ENCRYPTION_TAG_LENGTH,
281 )
282 .await?;
283
284 let msg = self
286 .cipher
287 .decrypt(&self.nonce.encode(), msg.as_ref())
288 .map_err(|_| Error::DecryptionFailed)?;
289 self.nonce.inc()?;
290
291 Ok(Bytes::from(msg))
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Receiver as _, Sender as _};
    use commonware_runtime::{deterministic::Executor, mocks, Runner};

    /// Builds a cipher keyed with all-zero bytes; both ends of each test use this key.
    fn zero_key_cipher() -> ChaCha20Poly1305 {
        ChaCha20Poly1305::new(&[0u8; 32].into())
    }

    #[test]
    fn test_decryption_failure() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let (mut raw_sink, stream) = mocks::Channel::init();
            let limit = 1024;
            let mut receiver = Receiver {
                cipher: zero_key_cipher(),
                stream,
                max_message_size: limit,
                nonce: nonce::Info::new(false),
            };

            // Bypass `Sender` and write an unencrypted frame that cannot authenticate.
            send_frame(&mut raw_sink, b"invalid data", limit)
                .await
                .unwrap();

            let outcome = receiver.receive().await;
            assert!(matches!(outcome, Err(Error::DecryptionFailed)));
        });
    }

    #[test]
    fn test_send_too_large() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let payload = b"hello world";
            let (sink, _) = mocks::Channel::init();
            let mut sender = Sender {
                cipher: zero_key_cipher(),
                sink,
                // One byte below the payload length, so the send must be rejected.
                max_message_size: payload.len() - 1,
                nonce: nonce::Info::new(true),
            };

            let outcome = sender.send(payload).await;
            let ciphertext_len = payload.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(outcome, Err(Error::SendTooLarge(n)) if n == ciphertext_len));
        });
    }

    #[test]
    fn test_receive_too_large() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let payload = b"hello world";
            let cipher = zero_key_cipher();
            let (sink, stream) = mocks::Channel::init();

            // The sender accepts the payload, but the receiver's limit is one byte smaller.
            let mut sender = Sender {
                cipher: cipher.clone(),
                sink,
                max_message_size: payload.len(),
                nonce: nonce::Info::new(true),
            };
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: payload.len() - 1,
                nonce: nonce::Info::new(false),
            };

            sender.send(payload).await.unwrap();
            let outcome = receiver.receive().await;
            let ciphertext_len = payload.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(outcome, Err(Error::RecvTooLarge(n)) if n == ciphertext_len));
        });
    }

    #[test]
    fn test_send_receive() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let payload = b"hello world";
            let limit = payload.len();
            let cipher = zero_key_cipher();
            let (sink, stream) = mocks::Channel::init();

            // Both halves sit on one one-way channel, so they must share a nonce sequence.
            let dialer = false;
            let mut sender = Sender {
                cipher: cipher.clone(),
                sink,
                max_message_size: limit,
                nonce: nonce::Info::new(dialer),
            };
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: limit,
                nonce: nonce::Info::new(dialer),
            };

            sender.send(payload).await.unwrap();
            let received = receiver.receive().await.unwrap();
            assert_eq!(received, &payload[..]);
        });
    }
}