1use super::{
2    handshake::{create_handshake, Handshake, IncomingHandshake},
3    nonce, x25519, Config,
4};
5use crate::{
6    utils::codec::{recv_frame, send_frame},
7    Error,
8};
9use bytes::Bytes;
10use chacha20poly1305::{
11    aead::{Aead, KeyInit},
12    ChaCha20Poly1305,
13};
14use commonware_cryptography::Scheme;
15use commonware_macros::select;
16use commonware_runtime::{Clock, Sink, Spawner, Stream};
17use commonware_utils::SystemTimeExt as _;
18use rand::{CryptoRng, Rng};
19
/// Size in bytes of the authentication tag appended to every ciphertext.
///
/// ChaCha20-Poly1305 emits a 128-bit Poly1305 tag, so each encrypted frame is
/// this many bytes longer than the plaintext it carries.
const ENCRYPTION_TAG_LENGTH: usize = 128 / 8;
23
/// An inbound connection whose dialer handshake has passed
/// `IncomingHandshake::verify`, but to which no response has been sent yet.
///
/// Created by [`IncomingConnection::verify`]; finish the exchange with
/// [`Connection::upgrade_listener`].
pub struct IncomingConnection<C: Scheme, Si: Sink, St: Stream> {
    // Configuration (signing scheme, namespace, limits) reused when the
    // listener builds its own handshake in `upgrade_listener`.
    config: Config<C>,
    // The verified handshake. It owns the transport sink/stream and carries the
    // peer's public key, ephemeral key, and the deadline for responding.
    handshake: IncomingHandshake<Si, St, C>,
}
29
30impl<C: Scheme, Si: Sink, St: Stream> IncomingConnection<C, Si, St> {
31    pub async fn verify<R: Rng + CryptoRng + Spawner + Clock>(
33        context: &R,
34        config: Config<C>,
35        sink: Si,
36        stream: St,
37    ) -> Result<Self, Error> {
38        let handshake = IncomingHandshake::verify(
39            context,
40            &config.crypto,
41            &config.namespace,
42            config.max_message_size,
43            config.synchrony_bound,
44            config.max_handshake_age,
45            config.handshake_timeout,
46            sink,
47            stream,
48        )
49        .await?;
50        Ok(Self { config, handshake })
51    }
52
53    pub fn peer(&self) -> C::PublicKey {
55        self.handshake.peer_public_key.clone()
56    }
57}
58
/// An encrypted, bidirectional connection ready to exchange messages.
///
/// Call [`Connection::split`] to obtain independent [`Sender`] and [`Receiver`] halves.
pub struct Connection<Si: Sink, St: Stream> {
    // Whether this side initiated the connection; `split` uses it to choose the
    // nonce sequence for each direction.
    dialer: bool,
    sink: Si,
    stream: St,
    // Symmetric cipher shared by both directions (for upgraded connections it is
    // keyed with the X25519 shared secret).
    cipher: ChaCha20Poly1305,
    // Maximum plaintext size; the on-wire frame limit adds the AEAD tag length.
    max_message_size: usize,
}
67
68impl<Si: Sink, St: Stream> Connection<Si, St> {
69    pub fn from_preestablished(
73        dialer: bool,
74        sink: Si,
75        stream: St,
76        cipher: ChaCha20Poly1305,
77        max_message_size: usize,
78    ) -> Self {
79        Self {
80            dialer,
81            sink,
82            stream,
83            cipher,
84            max_message_size,
85        }
86    }
87
88    pub async fn upgrade_dialer<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
93        mut context: R,
94        mut config: Config<C>,
95        mut sink: Si,
96        mut stream: St,
97        peer: C::PublicKey,
98    ) -> Result<Self, Error> {
99        let deadline = context.current() + config.handshake_timeout;
101
102        let secret = x25519::new(&mut context);
104        let ephemeral = x25519_dalek::PublicKey::from(&secret);
105
106        let timestamp = context.current().epoch_millis();
108        let msg = create_handshake(
109            &mut config.crypto,
110            &config.namespace,
111            timestamp,
112            peer.clone(),
113            ephemeral,
114        )?;
115
116        select! {
118            _ = context.sleep_until(deadline) => {
119                return Err(Error::HandshakeTimeout)
120            },
121            result = send_frame(&mut sink, &msg, config.max_message_size) => {
122                result.map_err(|_| Error::SendFailed)?;
123            },
124        }
125
126        let msg = select! {
128            _ = context.sleep_until(deadline) => {
129                return Err(Error::HandshakeTimeout)
130            },
131            result = recv_frame(&mut stream, config.max_message_size) => {
132                result.map_err(|_| Error::RecvFailed)?
133            },
134        };
135
136        let handshake = Handshake::verify(
138            &context,
139            &config.crypto,
140            &config.namespace,
141            config.synchrony_bound,
142            config.max_handshake_age,
143            msg,
144        )?;
145
146        if peer != handshake.peer_public_key {
148            return Err(Error::WrongPeer);
149        }
150
151        let shared_secret = secret.diffie_hellman(&handshake.ephemeral_public_key);
153        let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
154            .map_err(|_| Error::CipherCreationFailed)?;
155
156        Ok(Self {
158            dialer: true,
159            sink,
160            stream,
161            cipher,
162            max_message_size: config.max_message_size,
163        })
164    }
165
166    pub async fn upgrade_listener<R: Rng + CryptoRng + Spawner + Clock, C: Scheme>(
172        mut context: R,
173        incoming: IncomingConnection<C, Si, St>,
174    ) -> Result<Self, Error> {
175        let secret = x25519::new(&mut context);
177        let ephemeral = x25519_dalek::PublicKey::from(&secret);
178
179        let (mut handshake, mut config) = (incoming.handshake, incoming.config);
181        let timestamp = context.current().epoch_millis();
182        let msg = create_handshake(
183            &mut config.crypto,
184            &config.namespace,
185            timestamp,
186            handshake.peer_public_key,
187            ephemeral,
188        )?;
189
190        select! {
192            _ = context.sleep_until(handshake.deadline) => {
193                return Err(Error::HandshakeTimeout)
194            },
195            result = send_frame(&mut handshake.sink, &msg, config.max_message_size) => {
196                result.map_err(|_| Error::SendFailed)?;
197            },
198        }
199
200        let shared_secret = secret.diffie_hellman(&handshake.ephemeral_public_key);
202        let cipher = ChaCha20Poly1305::new_from_slice(shared_secret.as_bytes())
203            .map_err(|_| Error::CipherCreationFailed)?;
204
205        Ok(Connection {
207            dialer: false,
208            sink: handshake.sink,
209            stream: handshake.stream,
210            cipher,
211            max_message_size: config.max_message_size,
212        })
213    }
214
215    pub fn split(self) -> (Sender<Si>, Receiver<St>) {
220        (
221            Sender {
222                cipher: self.cipher.clone(),
223                sink: self.sink,
224                max_message_size: self.max_message_size,
225                nonce: nonce::Info::new(self.dialer),
226            },
227            Receiver {
228                cipher: self.cipher,
229                stream: self.stream,
230                max_message_size: self.max_message_size,
231                nonce: nonce::Info::new(!self.dialer),
232            },
233        )
234    }
235}
236
/// Sending half of a [`Connection`]: encrypts each message and writes it as one frame.
pub struct Sender<Si: Sink> {
    cipher: ChaCha20Poly1305,
    sink: Si,

    // Maximum plaintext size; the frame limit passed to `send_frame` adds the tag length.
    max_message_size: usize,
    // Nonce for the next outgoing message; advanced after every successful encryption.
    nonce: nonce::Info,
}
245
246impl<Si: Sink> crate::Sender for Sender<Si> {
247    async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
248        let msg = self
250            .cipher
251            .encrypt(&self.nonce.encode(), msg.as_ref())
252            .map_err(|_| Error::EncryptionFailed)?;
253        self.nonce.inc()?;
254
255        send_frame(
257            &mut self.sink,
258            &msg,
259            self.max_message_size + ENCRYPTION_TAG_LENGTH,
260        )
261        .await?;
262        Ok(())
263    }
264}
265
/// Receiving half of a [`Connection`]: reads frames and decrypts them in order.
pub struct Receiver<St: Stream> {
    cipher: ChaCha20Poly1305,
    stream: St,

    // Maximum plaintext size; the frame limit passed to `recv_frame` adds the tag length.
    max_message_size: usize,
    // Nonce expected on the next incoming message; advanced after every successful decryption.
    nonce: nonce::Info,
}
274
275impl<St: Stream> crate::Receiver for Receiver<St> {
276    async fn receive(&mut self) -> Result<Bytes, Error> {
277        let msg = recv_frame(
279            &mut self.stream,
280            self.max_message_size + ENCRYPTION_TAG_LENGTH,
281        )
282        .await?;
283
284        let msg = self
286            .cipher
287            .decrypt(&self.nonce.encode(), msg.as_ref())
288            .map_err(|_| Error::DecryptionFailed)?;
289        self.nonce.inc()?;
290
291        Ok(Bytes::from(msg))
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Receiver as _, Sender as _};
    use commonware_runtime::{deterministic::Executor, mocks, Runner};

    #[test]
    fn test_decryption_failure() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let (mut sink, stream) = mocks::Channel::init();
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: 1024,
                nonce: nonce::Info::new(false),
            };

            // Write a frame that is not a valid ciphertext.
            send_frame(&mut sink, b"invalid data", receiver.max_message_size)
                .await
                .unwrap();

            let result = receiver.receive().await;
            assert!(matches!(result, Err(Error::DecryptionFailed)));
        });
    }

    #[test]
    fn test_send_too_large() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let message = b"hello world";
            let (sink, _) = mocks::Channel::init();
            let mut sender = Sender {
                cipher,
                sink,
                max_message_size: message.len() - 1,
                nonce: nonce::Info::new(true),
            };

            let result = sender.send(message).await;
            let expected_length = message.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(result, Err(Error::SendTooLarge(n)) if n == expected_length));
        });
    }

    #[test]
    fn test_receive_too_large() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let message = b"hello world";
            let (sink, stream) = mocks::Channel::init();

            let mut sender = Sender {
                cipher: cipher.clone(),
                sink,
                max_message_size: message.len(),
                nonce: nonce::Info::new(true),
            };
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: message.len() - 1,
                nonce: nonce::Info::new(false),
            };

            sender.send(message).await.unwrap();
            let result = receiver.receive().await;
            let expected_length = message.len() + ENCRYPTION_TAG_LENGTH;
            assert!(matches!(result, Err(Error::RecvTooLarge(n)) if n == expected_length));
        });
    }

    #[test]
    fn test_send_receive() {
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let message = b"hello world";
            let max_message_size = message.len();

            let (sink, stream) = mocks::Channel::init();
            // Sender and receiver model the same direction, so they share the nonce sequence.
            let is_dialer = false;
            let mut sender = Sender {
                cipher: cipher.clone(),
                sink,
                max_message_size,
                nonce: nonce::Info::new(is_dialer),
            };
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size,
                nonce: nonce::Info::new(is_dialer),
            };

            sender.send(message).await.unwrap();
            let data = receiver.receive().await.unwrap();
            assert_eq!(data, &message[..]);
        });
    }

    #[test]
    fn test_send_receive_multiple() {
        // Several messages in a row only decrypt if sender and receiver advance
        // their nonces in lockstep.
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let (sink, stream) = mocks::Channel::init();
            let mut sender = Sender {
                cipher: cipher.clone(),
                sink,
                max_message_size: 1024,
                nonce: nonce::Info::new(true),
            };
            let mut receiver = Receiver {
                cipher,
                stream,
                max_message_size: 1024,
                nonce: nonce::Info::new(true),
            };

            let messages: [&[u8]; 3] = [b"first", b"second", b"third"];
            for message in messages {
                sender.send(message).await.unwrap();
            }
            for message in messages {
                let data = receiver.receive().await.unwrap();
                assert_eq!(data, message);
            }
        });
    }

    #[test]
    fn test_split_roundtrip() {
        // A dialer and a listener built from the same cipher must be able to talk in
        // both directions after `split`. This exercises the `dialer` / `!dialer`
        // nonce assignment.
        let (executor, _, _) = Executor::default();
        executor.start(async move {
            let cipher = ChaCha20Poly1305::new(&[0u8; 32].into());
            let max_message_size = 1024;

            // Two one-way channels: dialer -> listener and listener -> dialer.
            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            let dialer = Connection::from_preestablished(
                true,
                dialer_sink,
                dialer_stream,
                cipher.clone(),
                max_message_size,
            );
            let listener = Connection::from_preestablished(
                false,
                listener_sink,
                listener_stream,
                cipher,
                max_message_size,
            );
            let (mut dialer_sender, mut dialer_receiver) = dialer.split();
            let (mut listener_sender, mut listener_receiver) = listener.split();

            dialer_sender.send(b"ping").await.unwrap();
            let data = listener_receiver.receive().await.unwrap();
            assert_eq!(data, &b"ping"[..]);

            listener_sender.send(b"pong").await.unwrap();
            let data = dialer_receiver.receive().await.unwrap();
            assert_eq!(data, &b"pong"[..]);
        });
    }
}