1use crate::utils::codec::{recv_frame, send_frame, send_frame_with};
59use commonware_codec::{DecodeExt, Encode as _, EncodeSize, Error as CodecError, Write};
60use commonware_cryptography::{
61 handshake::{
62 self, dial_end, dial_start, listen_end, listen_start, Ack, Context,
63 Error as HandshakeError, RecvCipher, SendCipher, Syn, SynAck,
64 },
65 transcript::Transcript,
66 Signer,
67};
68use commonware_macros::select;
69use commonware_runtime::{
70 BufMut, BufferPool, BufferPooler, Clock, Error as RuntimeError, IoBufs, Sink, Stream,
71};
72use commonware_utils::{hex, SystemTimeExt};
73use rand_core::CryptoRngCore;
74use std::{future::Future, ops::Range, time::Duration};
75use thiserror::Error;
76
77const TAG_SIZE: u32 = {
78 assert!(handshake::TAG_SIZE <= u32::MAX as usize);
79 handshake::TAG_SIZE as u32
80};
81
82#[derive(Error, Debug)]
84pub enum Error {
85 #[error("handshake error: {0}")]
86 HandshakeError(HandshakeError),
87 #[error("unable to decode: {0}")]
88 UnableToDecode(CodecError),
89 #[error("peer rejected: {}", hex(_0))]
90 PeerRejected(Vec<u8>),
91 #[error("recv failed")]
92 RecvFailed(RuntimeError),
93 #[error("recv too large: {0} bytes")]
94 RecvTooLarge(usize),
95 #[error("invalid varint length prefix")]
96 InvalidVarint,
97 #[error("send failed")]
98 SendFailed(RuntimeError),
99 #[error("send zero size")]
100 SendZeroSize,
101 #[error("send too large: {0} bytes")]
102 SendTooLarge(usize),
103 #[error("connection closed")]
104 StreamClosed,
105 #[error("handshake timed out")]
106 HandshakeTimeout,
107}
108
109impl From<CodecError> for Error {
110 fn from(value: CodecError) -> Self {
111 Self::UnableToDecode(value)
112 }
113}
114
115impl From<HandshakeError> for Error {
116 fn from(value: HandshakeError) -> Self {
117 Self::HandshakeError(value)
118 }
119}
120
/// Settings consumed by [`dial`] and [`listen`] when establishing an encrypted connection.
#[derive(Clone)]
pub struct Config<S> {
    /// Key identifying the local party. The dialer announces its public key
    /// first, and the key itself is handed to the handshake context.
    pub signing_key: S,

    /// Bytes used to seed the handshake transcript.
    pub namespace: Vec<u8>,

    /// Size limit, in bytes, passed to the framing helpers for every handshake
    /// frame. After the handshake it bounds each message, with the cipher tag
    /// size added on top for the encrypted frame.
    pub max_message_size: u32,

    /// How far ahead of the local clock the accepted timestamp range extends
    /// (see [`Config::time_information`]).
    pub synchrony_bound: Duration,

    /// How far behind the local clock the accepted timestamp range extends
    /// (see [`Config::time_information`]).
    pub max_handshake_age: Duration,

    /// Time allowed for the whole handshake; when it elapses first, the
    /// handshake fails with [`Error::HandshakeTimeout`].
    pub handshake_timeout: Duration,
}
150
151impl<S> Config<S> {
152 pub fn time_information(&self, ctx: &impl Clock) -> (u64, Range<u64>) {
154 fn duration_to_u64(d: Duration) -> u64 {
155 u64::try_from(d.as_millis()).expect("duration ms should fit in an u64")
156 }
157 let current_time_ms = duration_to_u64(ctx.current().epoch());
158 let ok_timestamps = (current_time_ms
159 .saturating_sub(duration_to_u64(self.max_handshake_age)))
160 ..(current_time_ms.saturating_add(duration_to_u64(self.synchrony_bound)));
161 (current_time_ms, ok_timestamps)
162 }
163}
164
165pub async fn dial<R: BufferPooler + CryptoRngCore + Clock, S: Signer, I: Stream, O: Sink>(
168 mut ctx: R,
169 config: Config<S>,
170 peer: S::PublicKey,
171 mut stream: I,
172 mut sink: O,
173) -> Result<(Sender<O>, Receiver<I>), Error> {
174 let pool = ctx.network_buffer_pool().clone();
175 let timeout = ctx.sleep(config.handshake_timeout);
176 let inner_routine = async move {
177 send_frame(
178 &mut sink,
179 config.signing_key.public_key().encode(),
180 config.max_message_size,
181 )
182 .await?;
183
184 let (current_time, ok_timestamps) = config.time_information(&ctx);
185 let (state, syn) = dial_start(
186 &mut ctx,
187 Context::new(
188 &Transcript::new(&config.namespace),
189 current_time,
190 ok_timestamps,
191 config.signing_key,
192 peer,
193 ),
194 );
195 send_frame(&mut sink, syn.encode(), config.max_message_size).await?;
196
197 let syn_ack_bytes = recv_frame(&mut stream, config.max_message_size).await?;
198 let syn_ack = SynAck::<S::Signature>::decode(syn_ack_bytes)?;
199
200 let (ack, send, recv) = dial_end(state, syn_ack)?;
201 send_frame(&mut sink, ack.encode(), config.max_message_size).await?;
202
203 Ok((
204 Sender {
205 cipher: send,
206 sink,
207 max_message_size: config.max_message_size,
208 pool: pool.clone(),
209 },
210 Receiver {
211 cipher: recv,
212 stream,
213 max_message_size: config.max_message_size,
214 pool,
215 },
216 ))
217 };
218
219 select! {
220 x = inner_routine => x,
221 _ = timeout => Err(Error::HandshakeTimeout),
222 }
223}
224
225pub async fn listen<
228 R: BufferPooler + CryptoRngCore + Clock,
229 S: Signer,
230 I: Stream,
231 O: Sink,
232 Fut: Future<Output = bool>,
233 F: FnOnce(S::PublicKey) -> Fut,
234>(
235 mut ctx: R,
236 bouncer: F,
237 config: Config<S>,
238 mut stream: I,
239 mut sink: O,
240) -> Result<(S::PublicKey, Sender<O>, Receiver<I>), Error> {
241 let pool = ctx.network_buffer_pool().clone();
242 let timeout = ctx.sleep(config.handshake_timeout);
243 let inner_routine = async move {
244 let peer_bytes = recv_frame(&mut stream, config.max_message_size).await?;
245 let peer = S::PublicKey::decode(peer_bytes)?;
246 if !bouncer(peer.clone()).await {
247 return Err(Error::PeerRejected(peer.encode().to_vec()));
248 }
249
250 let msg1_bytes = recv_frame(&mut stream, config.max_message_size).await?;
251 let msg1 = Syn::<S::Signature>::decode(msg1_bytes)?;
252
253 let (current_time, ok_timestamps) = config.time_information(&ctx);
254 let (state, syn_ack) = listen_start(
255 &mut ctx,
256 Context::new(
257 &Transcript::new(&config.namespace),
258 current_time,
259 ok_timestamps,
260 config.signing_key,
261 peer.clone(),
262 ),
263 msg1,
264 )?;
265 send_frame(&mut sink, syn_ack.encode(), config.max_message_size).await?;
266
267 let ack_bytes = recv_frame(&mut stream, config.max_message_size).await?;
268 let ack = Ack::decode(ack_bytes)?;
269
270 let (send, recv) = listen_end(state, ack)?;
271
272 Ok((
273 peer,
274 Sender {
275 cipher: send,
276 sink,
277 max_message_size: config.max_message_size,
278 pool: pool.clone(),
279 },
280 Receiver {
281 cipher: recv,
282 stream,
283 max_message_size: config.max_message_size,
284 pool,
285 },
286 ))
287 };
288
289 select! {
290 x = inner_routine => x,
291 _ = timeout => Err(Error::HandshakeTimeout),
292 }
293}
294
295pub struct Sender<O> {
297 cipher: SendCipher,
298 sink: O,
299 max_message_size: u32,
300 pool: BufferPool,
301}
302
303impl<O: Sink> Sender<O> {
304 pub async fn send(&mut self, bufs: impl Into<IoBufs>) -> Result<(), Error> {
309 let mut bufs = bufs.into();
310 let ciphertext_len = bufs.len() + TAG_SIZE as usize;
311
312 send_frame_with(
313 &mut self.sink,
314 ciphertext_len,
315 self.max_message_size.saturating_add(TAG_SIZE),
316 |prefix| {
317 let prefix_len = prefix.encode_size();
318
319 let mut frame = self.pool.alloc(prefix_len + ciphertext_len);
321
322 prefix.write(&mut frame);
324
325 frame.put(&mut bufs);
327
328 let tag = self
330 .cipher
331 .send_in_place(&mut frame.as_mut()[prefix_len..])?;
332
333 frame.put_slice(&tag);
335
336 Ok(frame.freeze().into())
337 },
338 )
339 .await
340 }
341}
342
343pub struct Receiver<I> {
345 cipher: RecvCipher,
346 stream: I,
347 max_message_size: u32,
348 pool: BufferPool,
349}
350
351impl<I: Stream> Receiver<I> {
352 pub async fn recv(&mut self) -> Result<IoBufs, Error> {
357 let mut encrypted = recv_frame(
358 &mut self.stream,
359 self.max_message_size.saturating_add(TAG_SIZE),
360 )
361 .await?;
362 let ciphertext_len = encrypted.len();
363
364 let mut decryption_buf = self.pool.alloc(ciphertext_len);
366
367 decryption_buf.put(&mut encrypted);
369
370 let plaintext_len = self.cipher.recv_in_place(decryption_buf.as_mut())?;
372
373 decryption_buf.truncate(plaintext_len);
375
376 Ok(decryption_buf.freeze().into())
377 }
378}
379
#[cfg(test)]
mod test {
    use super::*;
    use commonware_cryptography::{ed25519::PrivateKey, Signer};
    use commonware_runtime::{deterministic, mocks, Runner as _, Spawner as _};

    const NAMESPACE: &[u8] = b"fuzz_transport";
    const MAX_MESSAGE_SIZE: u32 = 64 * 1024;

    /// Builds the configuration both parties use, differing only in the signing key.
    fn config_for(signing_key: PrivateKey) -> Config<PrivateKey> {
        Config {
            signing_key,
            namespace: NAMESPACE.to_vec(),
            max_message_size: MAX_MESSAGE_SIZE,
            synchrony_bound: Duration::from_secs(1),
            max_handshake_age: Duration::from_secs(1),
            handshake_timeout: Duration::from_secs(1),
        }
    }

    /// Completes a handshake over in-memory channels, then echoes a few
    /// messages in both directions and checks they arrive intact.
    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let dialer_crypto = PrivateKey::from_seed(42);
            let listener_crypto = PrivateKey::from_seed(24);

            // One channel per direction.
            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config = config_for(dialer_crypto.clone());
            let listener_config = config_for(listener_crypto.clone());

            // The listener runs concurrently and admits every peer.
            let listener_handle = context.clone().spawn(move |context| async move {
                listen(
                    context,
                    |_| async { true },
                    listener_config,
                    listener_stream,
                    listener_sink,
                )
                .await
            });

            let (mut dialer_sender, mut dialer_receiver) = dial(
                context,
                dialer_config,
                listener_crypto.public_key(),
                dialer_stream,
                dialer_sink,
            )
            .await?;

            let (listener_peer, mut listener_sender, mut listener_receiver) =
                listener_handle.await.unwrap()?;
            assert_eq!(listener_peer, dialer_crypto.public_key());

            let messages: [&'static [u8]; 3] = [b"A", b"B", b"C"];
            for msg in &messages {
                // Dialer to listener.
                dialer_sender.send(&msg[..]).await?;
                let at_listener = listener_receiver.recv().await?;
                assert_eq!(at_listener.coalesce(), *msg);

                // Listener to dialer.
                listener_sender.send(&msg[..]).await?;
                let at_dialer = dialer_receiver.recv().await?;
                assert_eq!(at_dialer.coalesce(), *msg);
            }
            Ok(())
        })
    }
}