1#![doc(
59 html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
60 html_favicon_url = "https://commonware.xyz/favicon.ico"
61)]
62
63pub mod utils;
64
65use crate::utils::codec::{recv_frame, send_frame};
66use bytes::Bytes;
67use commonware_codec::{DecodeExt, Encode as _, Error as CodecError};
68use commonware_cryptography::{
69 handshake::{
70 dial_end, dial_start, listen_end, listen_start, Ack, Context, Error as HandshakeError,
71 RecvCipher, SendCipher, Syn, SynAck, CIPHERTEXT_OVERHEAD,
72 },
73 Signer,
74};
75use commonware_macros::select;
76use commonware_runtime::{Clock, Error as RuntimeError, Sink, Stream};
77use commonware_utils::{hex, SystemTimeExt};
78use rand_core::CryptoRngCore;
79use std::{future::Future, ops::Range, time::Duration};
80use thiserror::Error;
81
82#[derive(Error, Debug)]
84pub enum Error {
85 #[error("handshake error: {0}")]
86 HandshakeError(HandshakeError),
87 #[error("unable to decode: {0}")]
88 UnableToDecode(CodecError),
89 #[error("peer rejected: {}", hex(_0))]
90 PeerRejected(Vec<u8>),
91 #[error("recv failed")]
92 RecvFailed(RuntimeError),
93 #[error("recv too large: {0} bytes")]
94 RecvTooLarge(usize),
95 #[error("send failed")]
96 SendFailed(RuntimeError),
97 #[error("send zero size")]
98 SendZeroSize,
99 #[error("send too large: {0} bytes")]
100 SendTooLarge(usize),
101 #[error("connection closed")]
102 StreamClosed,
103 #[error("handshake timed out")]
104 HandshakeTimeout,
105}
106
107impl From<CodecError> for Error {
108 fn from(value: CodecError) -> Self {
109 Self::UnableToDecode(value)
110 }
111}
112
113impl From<HandshakeError> for Error {
114 fn from(value: HandshakeError) -> Self {
115 Self::HandshakeError(value)
116 }
117}
118
/// Parameters shared by [`dial`] and [`listen`].
#[derive(Clone)]
pub struct Config<S> {
    /// Local identity key. Its public key is the first frame a dialer sends,
    /// and the key itself is handed to the handshake `Context`.
    pub signing_key: S,

    /// Application namespace.
    ///
    /// NOTE(review): not read by `dial`/`listen` in this file — confirm
    /// whether it is meant to be bound into the handshake.
    pub namespace: Vec<u8>,

    /// Frame size limit, in bytes, for handshake frames, and the plaintext
    /// limit for post-handshake messages (ciphertext frames may be
    /// `CIPHERTEXT_OVERHEAD` bytes larger).
    pub max_message_size: usize,

    /// Amount added to the current time to form the upper (exclusive) end
    /// of the accepted timestamp window built by `time_information`.
    pub synchrony_bound: Duration,

    /// Amount subtracted from the current time to form the lower end of
    /// the accepted timestamp window built by `time_information`.
    pub max_handshake_age: Duration,

    /// Deadline for the whole `dial`/`listen` exchange; exceeding it yields
    /// `Error::HandshakeTimeout`.
    pub handshake_timeout: Duration,
}
148
149impl<S> Config<S> {
150 pub fn time_information(&self, ctx: &impl Clock) -> (u64, Range<u64>) {
152 fn duration_to_u64(d: Duration) -> u64 {
153 u64::try_from(d.as_millis()).expect("duration ms should fit in an u64")
154 }
155 let current_time_ms = duration_to_u64(ctx.current().epoch());
156 let ok_timestamps = (current_time_ms
157 .saturating_sub(duration_to_u64(self.max_handshake_age)))
158 ..(current_time_ms.saturating_add(duration_to_u64(self.synchrony_bound)));
159 (current_time_ms, ok_timestamps)
160 }
161}
162
163pub async fn dial<R: CryptoRngCore + Clock, S: Signer, I: Stream, O: Sink>(
166 mut ctx: R,
167 config: Config<S>,
168 peer: S::PublicKey,
169 mut stream: I,
170 mut sink: O,
171) -> Result<(Sender<O>, Receiver<I>), Error> {
172 let timeout = ctx.sleep(config.handshake_timeout);
173 let inner_routine = async move {
174 send_frame(
175 &mut sink,
176 config.signing_key.public_key().encode().as_ref(),
177 config.max_message_size,
178 )
179 .await?;
180
181 let (current_time, ok_timestamps) = config.time_information(&ctx);
182 let (state, syn) = dial_start(
183 &mut ctx,
184 Context::new(current_time, ok_timestamps, config.signing_key, peer),
185 );
186 send_frame(&mut sink, &syn.encode(), config.max_message_size).await?;
187
188 let syn_ack_bytes = recv_frame(&mut stream, config.max_message_size).await?;
189 let syn_ack = SynAck::<S::Signature>::decode(syn_ack_bytes)?;
190
191 let (ack, send, recv) = dial_end(state, syn_ack)?;
192 send_frame(&mut sink, &ack.encode(), config.max_message_size).await?;
193
194 Ok((
195 Sender {
196 cipher: send,
197 sink,
198 max_message_size: config.max_message_size,
199 },
200 Receiver {
201 cipher: recv,
202 stream,
203 max_message_size: config.max_message_size,
204 },
205 ))
206 };
207
208 select! {
209 x = inner_routine => { x } ,
210 _ = timeout => { Err(Error::HandshakeTimeout) }
211 }
212}
213
214pub async fn listen<
217 R: CryptoRngCore + Clock,
218 S: Signer,
219 I: Stream,
220 O: Sink,
221 Fut: Future<Output = bool>,
222 F: FnOnce(S::PublicKey) -> Fut,
223>(
224 mut ctx: R,
225 bouncer: F,
226 config: Config<S>,
227 mut stream: I,
228 mut sink: O,
229) -> Result<(S::PublicKey, Sender<O>, Receiver<I>), Error> {
230 let timeout = ctx.sleep(config.handshake_timeout);
231 let inner_routine = async move {
232 let peer_bytes = recv_frame(&mut stream, config.max_message_size).await?;
233 let peer = S::PublicKey::decode(peer_bytes)?;
234 if !bouncer(peer.clone()).await {
235 return Err(Error::PeerRejected(peer.encode().to_vec()));
236 }
237
238 let msg1_bytes = recv_frame(&mut stream, config.max_message_size).await?;
239 let msg1 = Syn::<S::Signature>::decode(msg1_bytes)?;
240
241 let (current_time, ok_timestamps) = config.time_information(&ctx);
242 let (state, syn_ack) = listen_start(
243 &mut ctx,
244 Context::new(
245 current_time,
246 ok_timestamps,
247 config.signing_key,
248 peer.clone(),
249 ),
250 msg1,
251 )?;
252 send_frame(&mut sink, &syn_ack.encode(), config.max_message_size).await?;
253
254 let ack_bytes = recv_frame(&mut stream, config.max_message_size).await?;
255 let ack = Ack::decode(ack_bytes)?;
256
257 let (send, recv) = listen_end(state, ack)?;
258
259 Ok((
260 peer,
261 Sender {
262 cipher: send,
263 sink,
264 max_message_size: config.max_message_size,
265 },
266 Receiver {
267 cipher: recv,
268 stream,
269 max_message_size: config.max_message_size,
270 },
271 ))
272 };
273
274 select! {
275 x = inner_routine => { x } ,
276 _ = timeout => { Err(Error::HandshakeTimeout) }
277 }
278}
279
280pub struct Sender<O> {
282 cipher: SendCipher,
283 sink: O,
284 max_message_size: usize,
285}
286
287impl<O: Sink> Sender<O> {
288 pub async fn send(&mut self, msg: &[u8]) -> Result<(), Error> {
290 let c = self.cipher.send(msg)?;
291 send_frame(
292 &mut self.sink,
293 &c,
294 self.max_message_size + CIPHERTEXT_OVERHEAD,
295 )
296 .await?;
297 Ok(())
298 }
299}
300
301pub struct Receiver<I> {
303 cipher: RecvCipher,
304 stream: I,
305 max_message_size: usize,
306}
307
308impl<I: Stream> Receiver<I> {
309 pub async fn recv(&mut self) -> Result<Bytes, Error> {
311 let c = recv_frame(
312 &mut self.stream,
313 self.max_message_size + CIPHERTEXT_OVERHEAD,
314 )
315 .await?;
316 Ok(self.cipher.recv(&c)?.into())
317 }
318}
319
#[cfg(test)]
mod test {
    use super::*;
    use commonware_cryptography::{ed25519::PrivateKey, PrivateKeyExt as _, Signer};
    use commonware_runtime::{deterministic, mocks, Runner as _, Spawner as _};

    const NAMESPACE: &[u8] = b"fuzz_transport";
    const MAX_MESSAGE_SIZE: usize = 64 * 1024;

    /// Builds a test config for `signing_key` with one-second time bounds.
    fn config_for(signing_key: PrivateKey) -> Config<PrivateKey> {
        Config {
            signing_key,
            namespace: NAMESPACE.to_vec(),
            max_message_size: MAX_MESSAGE_SIZE,
            synchrony_bound: Duration::from_secs(1),
            max_handshake_age: Duration::from_secs(1),
            handshake_timeout: Duration::from_secs(1),
        }
    }

    /// A dialer and a listener complete the handshake over in-memory channels,
    /// then echo messages in both directions.
    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            let dialer_key = PrivateKey::from_seed(42);
            let listener_key = PrivateKey::from_seed(24);

            // Two one-way mock channels, wired crosswise, form a duplex link.
            let (dialer_sink, listener_stream) = mocks::Channel::init();
            let (listener_sink, dialer_stream) = mocks::Channel::init();

            let dialer_config = config_for(dialer_key.clone());
            let listener_config = config_for(listener_key.clone());

            let listener_task = context.clone().spawn(move |context| async move {
                listen(
                    context,
                    |_| async { true },
                    listener_config,
                    listener_stream,
                    listener_sink,
                )
                .await
            });

            let (mut dialer_tx, mut dialer_rx) = dial(
                context,
                dialer_config,
                listener_key.public_key(),
                dialer_stream,
                dialer_sink,
            )
            .await?;

            let (seen_peer, mut listener_tx, mut listener_rx) = listener_task.await.unwrap()?;
            assert_eq!(seen_peer, dialer_key.public_key());

            let messages: Vec<&'static [u8]> = vec![b"A", b"B", b"C"];
            for msg in &messages {
                dialer_tx.send(msg).await?;
                let at_listener = listener_rx.recv().await?;
                assert_eq!(msg, &at_listener);

                listener_tx.send(msg).await?;
                let at_dialer = dialer_rx.recv().await?;
                assert_eq!(msg, &at_dialer);
            }
            Ok(())
        })
    }
}