1use crate::{
38 transcript::{Summary, Transcript},
39 PublicKey, Signature, Signer, Verifier,
40};
41use commonware_codec::{Encode, FixedSize, Read, ReadExt, Write};
42use core::ops::Range;
43use rand_core::CryptoRngCore;
44
45mod error;
46pub use error::Error;
47
48mod key_exchange;
49use key_exchange::{EphemeralPublicKey, SecretKey};
50
51mod cipher;
52pub use cipher::{RecvCipher, SendCipher, TAG_SIZE};
53
54#[cfg(all(test, feature = "arbitrary"))]
55mod conformance;
56
/// Label under which [`Context::new`] forks the caller's base transcript, separating
/// this handshake from any other use of the same base transcript.
const NAMESPACE: &[u8] = b"_COMMONWARE_CRYPTOGRAPHY_HANDSHAKE";
/// Noise label for the cipher carrying listener-to-dialer traffic
/// (the listener's send cipher and the dialer's receive cipher).
const LABEL_CIPHER_L2D: &[u8] = b"cipher_l2d";
/// Noise label for the cipher carrying dialer-to-listener traffic
/// (the dialer's send cipher and the listener's receive cipher).
const LABEL_CIPHER_D2L: &[u8] = b"cipher_d2l";
/// Fork label for the confirmation the listener sends to the dialer in [`SynAck`].
const LABEL_CONFIRMATION_L2D: &[u8] = b"confirmation_l2d";
/// Fork label for the confirmation the dialer sends to the listener in [`Ack`].
const LABEL_CONFIRMATION_D2L: &[u8] = b"confirmation_d2l";
62
63#[cfg_attr(test, derive(Debug, PartialEq))]
66pub struct Syn<S: Signature> {
67 time_ms: u64,
68 epk: EphemeralPublicKey,
69 sig: S,
70}
71
72impl<S: Signature> FixedSize for Syn<S> {
73 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE;
74}
75
76impl<S: Signature + Write> Write for Syn<S> {
77 fn write(&self, buf: &mut impl bytes::BufMut) {
78 self.time_ms.write(buf);
79 self.epk.write(buf);
80 self.sig.write(buf);
81 }
82}
83
84impl<S: Signature + Read> Read for Syn<S> {
85 type Cfg = S::Cfg;
86
87 fn read_cfg(
88 buf: &mut impl bytes::Buf,
89 cfg: &Self::Cfg,
90 ) -> Result<Self, commonware_codec::Error> {
91 Ok(Self {
92 time_ms: ReadExt::read(buf)?,
93 epk: ReadExt::read(buf)?,
94 sig: Read::read_cfg(buf, cfg)?,
95 })
96 }
97}
98
99#[cfg(feature = "arbitrary")]
100impl<S: Signature> arbitrary::Arbitrary<'_> for Syn<S>
101where
102 S: for<'a> arbitrary::Arbitrary<'a>,
103{
104 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
105 Ok(Self {
106 time_ms: u.arbitrary()?,
107 epk: u.arbitrary()?,
108 sig: u.arbitrary()?,
109 })
110 }
111}
112
113#[cfg_attr(test, derive(Debug, PartialEq))]
116pub struct SynAck<S: Signature> {
117 time_ms: u64,
118 epk: EphemeralPublicKey,
119 sig: S,
120 confirmation: Summary,
121}
122
123impl<S: Signature> FixedSize for SynAck<S> {
124 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE + Summary::SIZE;
125}
126
127impl<S: Signature + Write> Write for SynAck<S> {
128 fn write(&self, buf: &mut impl bytes::BufMut) {
129 self.time_ms.write(buf);
130 self.epk.write(buf);
131 self.sig.write(buf);
132 self.confirmation.write(buf);
133 }
134}
135
136impl<S: Signature + Read> Read for SynAck<S> {
137 type Cfg = S::Cfg;
138
139 fn read_cfg(
140 buf: &mut impl bytes::Buf,
141 cfg: &Self::Cfg,
142 ) -> Result<Self, commonware_codec::Error> {
143 Ok(Self {
144 time_ms: ReadExt::read(buf)?,
145 epk: ReadExt::read(buf)?,
146 sig: Read::read_cfg(buf, cfg)?,
147 confirmation: ReadExt::read(buf)?,
148 })
149 }
150}
151
152#[cfg(feature = "arbitrary")]
153impl<S: Signature> arbitrary::Arbitrary<'_> for SynAck<S>
154where
155 S: for<'a> arbitrary::Arbitrary<'a>,
156{
157 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
158 Ok(Self {
159 time_ms: u.arbitrary()?,
160 epk: u.arbitrary()?,
161 sig: u.arbitrary()?,
162 confirmation: u.arbitrary()?,
163 })
164 }
165}
166
167#[cfg_attr(test, derive(PartialEq))]
170#[cfg_attr(feature = "arbitrary", derive(Debug, arbitrary::Arbitrary))]
171pub struct Ack {
172 confirmation: Summary,
173}
174
175impl FixedSize for Ack {
176 const SIZE: usize = Summary::SIZE;
177}
178
179impl Write for Ack {
180 fn write(&self, buf: &mut impl bytes::BufMut) {
181 self.confirmation.write(buf);
182 }
183}
184
185impl Read for Ack {
186 type Cfg = ();
187
188 fn read_cfg(
189 buf: &mut impl bytes::Buf,
190 _cfg: &Self::Cfg,
191 ) -> Result<Self, commonware_codec::Error> {
192 Ok(Self {
193 confirmation: ReadExt::read(buf)?,
194 })
195 }
196}
197
/// State the dialer carries from [`dial_start`] to [`dial_end`].
pub struct DialState<P> {
    /// Dialer's ephemeral secret, used for the key exchange in [`dial_end`].
    esk: SecretKey,
    /// Public key the listener's signature is verified against.
    peer_identity: P,
    /// Transcript after the dialer committed its timestamp, the listener's key,
    /// its ephemeral key, and (after signing) its own public key.
    transcript: Transcript,
    /// Half-open range the listener's timestamp must fall in.
    ok_timestamps: Range<u64>,
}
206
/// State the listener carries from [`listen_start`] to [`listen_end`].
///
/// The ciphers are derived in [`listen_start`] but are only handed out by
/// [`listen_end`] once the dialer's [`Ack`] carries the expected confirmation.
pub struct ListenState {
    /// Confirmation the dialer's [`Ack`] must match.
    confirmation: Summary,
    /// Cipher for listener-to-dialer traffic.
    send: SendCipher,
    /// Cipher for dialer-to-listener traffic.
    recv: RecvCipher,
}
214
215pub struct Context<S, P> {
218 transcript: Transcript,
219 current_time: u64,
220 ok_timestamps: Range<u64>,
221 my_identity: S,
222 peer_identity: P,
223}
224
225impl<S, P> Context<S, P> {
226 pub fn new(
228 base: &Transcript,
229 current_time_ms: u64,
230 ok_timestamps: Range<u64>,
231 my_identity: S,
232 peer_identity: P,
233 ) -> Self {
234 Self {
235 transcript: base.fork(NAMESPACE),
236 current_time: current_time_ms,
237 ok_timestamps,
238 my_identity,
239 peer_identity,
240 }
241 }
242}
243
/// Begins a handshake as the dialer.
///
/// Generates a fresh ephemeral key pair and signs the transcript with the
/// context's identity. Returns the state needed by [`dial_end`] together with
/// the [`Syn`] message to send to the listener.
///
/// The signed transcript covers, in order: the local timestamp, the listener's
/// public key, and the ephemeral public key. The dialer's own public key is
/// committed only after signing. This is the same point at which
/// [`listen_start`] commits it.
pub fn dial_start<S: Signer, P: PublicKey>(
    rng: impl CryptoRngCore,
    ctx: Context<S, P>,
) -> (DialState<P>, Syn<<S as Signer>::Signature>) {
    let Context {
        current_time,
        ok_timestamps,
        my_identity,
        peer_identity,
        mut transcript,
    } = ctx;
    let esk = SecretKey::new(rng);
    let epk = esk.public();
    // Commit order must match what `listen_start` commits before verifying.
    let sig = transcript
        .commit(current_time.encode())
        .commit(peer_identity.encode())
        .commit(epk.encode())
        .sign(&my_identity);
    // Mirrors the listener committing the dialer's key right after verifying `sig`.
    transcript.commit(my_identity.public_key().encode());
    (
        DialState {
            esk,
            peer_identity,
            transcript,
            ok_timestamps,
        },
        Syn {
            time_ms: current_time,
            epk,
            sig,
        },
    )
}
279
/// Completes the handshake on the dialer's side.
///
/// Checks the listener's timestamp and signature, performs the ephemeral key
/// exchange, derives the session ciphers, and checks the listener's key
/// confirmation. On success, returns the [`Ack`] to send to the listener along
/// with the dialer's send and receive ciphers.
///
/// # Errors
///
/// - [`Error::InvalidTimestamp`] if the listener's timestamp is outside the
///   accepted range.
/// - [`Error::HandshakeFailed`] if the signature does not verify, the key
///   exchange yields no secret, or the confirmation does not match.
pub fn dial_end<P: PublicKey>(
    state: DialState<P>,
    msg: SynAck<<P as Verifier>::Signature>,
) -> Result<(Ack, SendCipher, RecvCipher), Error> {
    let DialState {
        esk,
        peer_identity,
        mut transcript,
        ok_timestamps,
    } = state;
    // Reject out-of-window timestamps before doing any cryptographic work.
    if !ok_timestamps.contains(&msg.time_ms) {
        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
    }
    // The listener signed our transcript extended with its timestamp and ephemeral key.
    if !transcript
        .commit(msg.time_ms.encode())
        .commit(msg.epk.encode())
        .verify(&peer_identity, &msg.sig)
    {
        return Err(Error::HandshakeFailed);
    }
    // `exchange` returns an `Option`; `None` aborts the handshake.
    let Some(shared) = esk.exchange(&msg.epk) else {
        return Err(Error::HandshakeFailed);
    };
    // Mix the shared secret into the transcript so everything derived below is bound to it.
    shared
        .secret
        .expose(|secret| transcript.commit(secret.as_ref()));
    // Directions mirror `listen_start`: the dialer receives on L2D and sends on D2L.
    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_L2D));
    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_D2L));
    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
    // The listener must have derived the same transcript before any cipher is returned.
    if msg.confirmation != confirmation_l2d {
        return Err(Error::HandshakeFailed);
    }

    Ok((
        Ack {
            confirmation: confirmation_d2l,
        },
        send,
        recv,
    ))
}
324
/// Responds to a dialer's [`Syn`] as the listener.
///
/// Checks the dialer's timestamp and signature, then generates an ephemeral
/// key pair, signs the extended transcript, performs the key exchange, and
/// derives the session ciphers. Returns the state needed by [`listen_end`]
/// together with the [`SynAck`] to send back to the dialer.
///
/// The dialer's key confirmation is not checked here; [`listen_end`] does that.
///
/// # Errors
///
/// - [`Error::InvalidTimestamp`] if the dialer's timestamp is outside the
///   accepted range.
/// - [`Error::HandshakeFailed`] if the dialer's signature does not verify or
///   the key exchange yields no secret.
pub fn listen_start<S: Signer, P: PublicKey>(
    rng: &mut impl CryptoRngCore,
    ctx: Context<S, P>,
    msg: Syn<<P as Verifier>::Signature>,
) -> Result<(ListenState, SynAck<<S as Signer>::Signature>), Error> {
    let Context {
        current_time,
        my_identity,
        peer_identity,
        ok_timestamps,
        mut transcript,
    } = ctx;
    // Reject out-of-window timestamps before doing any cryptographic work.
    if !ok_timestamps.contains(&msg.time_ms) {
        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
    }
    // Rebuild exactly what the dialer signed: its timestamp, our key, its ephemeral key.
    if !transcript
        .commit(msg.time_ms.encode())
        .commit(my_identity.public_key().encode())
        .commit(msg.epk.encode())
        .verify(&peer_identity, &msg.sig)
    {
        return Err(Error::HandshakeFailed);
    }
    let esk = SecretKey::new(rng);
    let epk = esk.public();
    // The dialer's key is committed here, matching where `dial_start` commits its own key;
    // the listener's timestamp and ephemeral key follow, and the result is signed.
    let sig = transcript
        .commit(peer_identity.encode())
        .commit(current_time.encode())
        .commit(epk.encode())
        .sign(&my_identity);
    // `exchange` returns an `Option`; `None` aborts the handshake.
    let Some(shared) = esk.exchange(&msg.epk) else {
        return Err(Error::HandshakeFailed);
    };
    // Mix the shared secret into the transcript so everything derived below is bound to it.
    shared
        .secret
        .expose(|secret| transcript.commit(secret.as_ref()));
    // Directions mirror `dial_end`: the listener sends on L2D and receives on D2L.
    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_L2D));
    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_D2L));
    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();

    // Keep the confirmation expected from the dialer; send ours in the SynAck.
    Ok((
        ListenState {
            confirmation: confirmation_d2l,
            send,
            recv,
        },
        SynAck {
            time_ms: current_time,
            epk,
            sig,
            confirmation: confirmation_l2d,
        },
    ))
}
382
383pub fn listen_end(state: ListenState, msg: Ack) -> Result<(SendCipher, RecvCipher), Error> {
386 if msg.confirmation != state.confirmation {
387 return Err(Error::HandshakeFailed);
388 }
389 Ok((state.send, state.recv))
390}
391
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ed25519::PrivateKey, transcript::Transcript, Signer};
    use commonware_codec::{Codec, DecodeExt};
    use commonware_math::algebra::Random;
    use commonware_utils::test_rng;

    /// Asserts that `value` survives an encode/decode round trip unchanged.
    fn test_encode_roundtrip<T: Codec<Cfg = ()> + PartialEq>(value: &T) {
        assert!(value == &<T as DecodeExt<_>>::decode(value.encode()).unwrap());
    }

    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (d_state, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );
        test_encode_roundtrip(&msg1);
        let (l_state, msg2) = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        )?;
        test_encode_roundtrip(&msg2);
        let (msg3, mut d_send, mut d_recv) = dial_end(d_state, msg2)?;
        test_encode_roundtrip(&msg3);
        let (mut l_send, mut l_recv) = listen_end(l_state, msg3)?;

        let m1: &'static [u8] = b"message 1";

        let c1 = d_send.send(m1)?;
        let m1_prime = l_recv.recv(&c1)?;
        assert_eq!(m1, &m1_prime);

        let m2: &'static [u8] = b"message 2";
        let c2 = l_send.send(m2)?;
        let m2_prime = d_recv.recv(&c2)?;
        assert_eq!(m2, &m2_prime);

        Ok(())
    }

    #[test]
    fn test_mismatched_namespace_fails() {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"namespace_a"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );

        let result = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"namespace_b"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        );

        assert!(matches!(result, Err(Error::HandshakeFailed)));
    }

    #[test]
    fn test_out_of_range_timestamp_fails() {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        // The dialer reports time 5, but the listener only accepts [0, 1).
        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                5,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );

        let result = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        );

        assert!(matches!(result, Err(Error::InvalidTimestamp(5, _))));
    }

    #[test]
    fn test_reflected_confirmation_fails() -> Result<(), Error> {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );
        let (l_state, msg2) = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        )?;

        // Echoing the listener's own (L2D) confirmation back must not pass as the
        // dialer's (D2L) confirmation.
        let reflected = Ack {
            confirmation: msg2.confirmation,
        };
        assert!(matches!(
            listen_end(l_state, reflected),
            Err(Error::HandshakeFailed)
        ));

        Ok(())
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Syn<crate::ed25519::Signature>>,
            CodecConformance<SynAck<crate::ed25519::Signature>>,
            CodecConformance<Ack>,
        }
    }
}