1use crate::{
36 transcript::{Summary, Transcript},
37 PublicKey, Signature, Signer, Verifier,
38};
39use commonware_codec::{Encode, FixedSize, Read, ReadExt, Write};
40use core::ops::Range;
41use rand_core::CryptoRngCore;
42
43mod error;
44pub use error::Error;
45
46mod key_exchange;
47use key_exchange::{EphemeralPublicKey, SecretKey};
48
49mod cipher;
50pub use cipher::{RecvCipher, SendCipher, CIPHERTEXT_OVERHEAD};
51
52#[cfg(all(test, feature = "arbitrary"))]
53mod conformance;
54
/// Namespace under which every handshake forks the caller-supplied base
/// transcript (see [`Context::new`]).
const NAMESPACE: &[u8] = b"_COMMONWARE_CRYPTOGRAPHY_HANDSHAKE";

/// Noise label for the listener-to-dialer cipher (listener sends, dialer receives).
const LABEL_CIPHER_L2D: &[u8] = b"cipher_l2d";
/// Noise label for the dialer-to-listener cipher (dialer sends, listener receives).
const LABEL_CIPHER_D2L: &[u8] = b"cipher_d2l";

/// Fork label for the confirmation the listener sends in [`SynAck`].
const LABEL_CONFIRMATION_L2D: &[u8] = b"confirmation_l2d";
/// Fork label for the confirmation the dialer sends in [`Ack`].
const LABEL_CONFIRMATION_D2L: &[u8] = b"confirmation_d2l";
60
61#[cfg_attr(test, derive(Debug, PartialEq))]
64pub struct Syn<S: Signature> {
65 time_ms: u64,
66 epk: EphemeralPublicKey,
67 sig: S,
68}
69
70impl<S: Signature> FixedSize for Syn<S> {
71 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE;
72}
73
74impl<S: Signature + Write> Write for Syn<S> {
75 fn write(&self, buf: &mut impl bytes::BufMut) {
76 self.time_ms.write(buf);
77 self.epk.write(buf);
78 self.sig.write(buf);
79 }
80}
81
82impl<S: Signature + Read> Read for Syn<S> {
83 type Cfg = S::Cfg;
84
85 fn read_cfg(
86 buf: &mut impl bytes::Buf,
87 cfg: &Self::Cfg,
88 ) -> Result<Self, commonware_codec::Error> {
89 Ok(Self {
90 time_ms: ReadExt::read(buf)?,
91 epk: ReadExt::read(buf)?,
92 sig: Read::read_cfg(buf, cfg)?,
93 })
94 }
95}
96
#[cfg(feature = "arbitrary")]
impl<S> arbitrary::Arbitrary<'_> for Syn<S>
where
    S: Signature + for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws each field from `u` in declaration order.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let time_ms: u64 = u.arbitrary()?;
        let epk: EphemeralPublicKey = u.arbitrary()?;
        let sig: S = u.arbitrary()?;
        Ok(Self { time_ms, epk, sig })
    }
}
110
111#[cfg_attr(test, derive(Debug, PartialEq))]
114pub struct SynAck<S: Signature> {
115 time_ms: u64,
116 epk: EphemeralPublicKey,
117 sig: S,
118 confirmation: Summary,
119}
120
121impl<S: Signature> FixedSize for SynAck<S> {
122 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE + Summary::SIZE;
123}
124
125impl<S: Signature + Write> Write for SynAck<S> {
126 fn write(&self, buf: &mut impl bytes::BufMut) {
127 self.time_ms.write(buf);
128 self.epk.write(buf);
129 self.sig.write(buf);
130 self.confirmation.write(buf);
131 }
132}
133
134impl<S: Signature + Read> Read for SynAck<S> {
135 type Cfg = S::Cfg;
136
137 fn read_cfg(
138 buf: &mut impl bytes::Buf,
139 cfg: &Self::Cfg,
140 ) -> Result<Self, commonware_codec::Error> {
141 Ok(Self {
142 time_ms: ReadExt::read(buf)?,
143 epk: ReadExt::read(buf)?,
144 sig: Read::read_cfg(buf, cfg)?,
145 confirmation: ReadExt::read(buf)?,
146 })
147 }
148}
149
#[cfg(feature = "arbitrary")]
impl<S> arbitrary::Arbitrary<'_> for SynAck<S>
where
    S: Signature + for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws each field from `u` in declaration order.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let time_ms: u64 = u.arbitrary()?;
        let epk: EphemeralPublicKey = u.arbitrary()?;
        let sig: S = u.arbitrary()?;
        let confirmation: Summary = u.arbitrary()?;
        Ok(Self {
            time_ms,
            epk,
            sig,
            confirmation,
        })
    }
}
164
165#[cfg_attr(test, derive(PartialEq))]
168#[cfg_attr(feature = "arbitrary", derive(Debug, arbitrary::Arbitrary))]
169pub struct Ack {
170 confirmation: Summary,
171}
172
173impl FixedSize for Ack {
174 const SIZE: usize = Summary::SIZE;
175}
176
177impl Write for Ack {
178 fn write(&self, buf: &mut impl bytes::BufMut) {
179 self.confirmation.write(buf);
180 }
181}
182
183impl Read for Ack {
184 type Cfg = ();
185
186 fn read_cfg(
187 buf: &mut impl bytes::Buf,
188 _cfg: &Self::Cfg,
189 ) -> Result<Self, commonware_codec::Error> {
190 Ok(Self {
191 confirmation: ReadExt::read(buf)?,
192 })
193 }
194}
195
196pub struct DialState<P> {
199 esk: SecretKey,
200 peer_identity: P,
201 transcript: Transcript,
202 ok_timestamps: Range<u64>,
203}
204
205pub struct ListenState {
208 confirmation: Summary,
209 send: SendCipher,
210 recv: RecvCipher,
211}
212
213pub struct Context<S, P> {
216 transcript: Transcript,
217 current_time: u64,
218 ok_timestamps: Range<u64>,
219 my_identity: S,
220 peer_identity: P,
221}
222
223impl<S, P> Context<S, P> {
224 pub fn new(
226 base: &Transcript,
227 current_time_ms: u64,
228 ok_timestamps: Range<u64>,
229 my_identity: S,
230 peer_identity: P,
231 ) -> Self {
232 Self {
233 transcript: base.fork(NAMESPACE),
234 current_time: current_time_ms,
235 ok_timestamps,
236 my_identity,
237 peer_identity,
238 }
239 }
240}
241
242pub fn dial_start<S: Signer, P: PublicKey>(
245 rng: impl CryptoRngCore,
246 ctx: Context<S, P>,
247) -> (DialState<P>, Syn<<S as Signer>::Signature>) {
248 let Context {
249 current_time,
250 ok_timestamps,
251 my_identity,
252 peer_identity,
253 mut transcript,
254 } = ctx;
255 let esk = SecretKey::new(rng);
256 let epk = esk.public();
257 let sig = transcript
258 .commit(current_time.encode())
259 .commit(peer_identity.encode())
260 .commit(epk.encode())
261 .sign(&my_identity);
262 transcript.commit(my_identity.public_key().encode());
263 (
264 DialState {
265 esk,
266 peer_identity,
267 transcript,
268 ok_timestamps,
269 },
270 Syn {
271 time_ms: current_time,
272 epk,
273 sig,
274 },
275 )
276}
277
278pub fn dial_end<P: PublicKey>(
281 state: DialState<P>,
282 msg: SynAck<<P as Verifier>::Signature>,
283) -> Result<(Ack, SendCipher, RecvCipher), Error> {
284 let DialState {
285 esk,
286 peer_identity,
287 mut transcript,
288 ok_timestamps,
289 } = state;
290 if !ok_timestamps.contains(&msg.time_ms) {
291 return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
292 }
293 if !transcript
294 .commit(msg.time_ms.encode())
295 .commit(msg.epk.encode())
296 .verify(&peer_identity, &msg.sig)
297 {
298 return Err(Error::HandshakeFailed);
299 }
300 let Some(shared) = esk.exchange(&msg.epk) else {
301 return Err(Error::HandshakeFailed);
302 };
303 shared
304 .secret
305 .expose(|secret| transcript.commit(secret.as_ref()));
306 let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_L2D));
307 let send = SendCipher::new(transcript.noise(LABEL_CIPHER_D2L));
308 let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
309 let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
310 if msg.confirmation != confirmation_l2d {
311 return Err(Error::HandshakeFailed);
312 }
313
314 Ok((
315 Ack {
316 confirmation: confirmation_d2l,
317 },
318 send,
319 recv,
320 ))
321}
322
323pub fn listen_start<S: Signer, P: PublicKey>(
326 rng: &mut impl CryptoRngCore,
327 ctx: Context<S, P>,
328 msg: Syn<<P as Verifier>::Signature>,
329) -> Result<(ListenState, SynAck<<S as Signer>::Signature>), Error> {
330 let Context {
331 current_time,
332 my_identity,
333 peer_identity,
334 ok_timestamps,
335 mut transcript,
336 } = ctx;
337 if !ok_timestamps.contains(&msg.time_ms) {
338 return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
339 }
340 if !transcript
341 .commit(msg.time_ms.encode())
342 .commit(my_identity.public_key().encode())
343 .commit(msg.epk.encode())
344 .verify(&peer_identity, &msg.sig)
345 {
346 return Err(Error::HandshakeFailed);
347 }
348 let esk = SecretKey::new(rng);
349 let epk = esk.public();
350 let sig = transcript
351 .commit(peer_identity.encode())
352 .commit(current_time.encode())
353 .commit(epk.encode())
354 .sign(&my_identity);
355 let Some(shared) = esk.exchange(&msg.epk) else {
356 return Err(Error::HandshakeFailed);
357 };
358 shared
359 .secret
360 .expose(|secret| transcript.commit(secret.as_ref()));
361 let send = SendCipher::new(transcript.noise(LABEL_CIPHER_L2D));
362 let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_D2L));
363 let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
364 let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
365
366 Ok((
367 ListenState {
368 confirmation: confirmation_d2l,
369 send,
370 recv,
371 },
372 SynAck {
373 time_ms: current_time,
374 epk,
375 sig,
376 confirmation: confirmation_l2d,
377 },
378 ))
379}
380
381pub fn listen_end(state: ListenState, msg: Ack) -> Result<(SendCipher, RecvCipher), Error> {
384 if msg.confirmation != state.confirmation {
385 return Err(Error::HandshakeFailed);
386 }
387 Ok((state.send, state.recv))
388}
389
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ed25519::PrivateKey, transcript::Transcript, Signer};
    use commonware_codec::{Codec, DecodeExt};
    use commonware_math::algebra::Random;
    use commonware_utils::test_rng;

    // `assert!` rather than `assert_eq!`: `Ack` only derives `Debug` under the
    // `arbitrary` feature.
    fn test_encode_roundtrip<T: Codec<Cfg = ()> + PartialEq>(value: &T) {
        assert!(value == &<T as DecodeExt<_>>::decode(value.encode()).unwrap());
    }

    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (d_state, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );
        test_encode_roundtrip(&msg1);
        let (l_state, msg2) = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        )?;
        test_encode_roundtrip(&msg2);
        let (msg3, mut d_send, mut d_recv) = dial_end(d_state, msg2)?;
        test_encode_roundtrip(&msg3);
        let (mut l_send, mut l_recv) = listen_end(l_state, msg3)?;

        let m1: &'static [u8] = b"message 1";

        let c1 = d_send.send(m1)?;
        let m1_prime = l_recv.recv(&c1)?;
        assert_eq!(m1, &m1_prime);

        let m2: &'static [u8] = b"message 2";
        let c2 = l_send.send(m2)?;
        let m2_prime = d_recv.recv(&c2)?;
        assert_eq!(m2, &m2_prime);

        Ok(())
    }

    #[test]
    fn test_mismatched_namespace_fails() {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"namespace_a"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );

        let result = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"namespace_b"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        );

        assert!(matches!(result, Err(Error::InvalidTimestamp(..)) | Err(Error::HandshakeFailed)));
        assert!(matches!(result, Err(Error::HandshakeFailed)));
    }

    #[test]
    fn test_listener_rejects_out_of_window_timestamp() {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        // The dialer advertises t=5, but the listener only accepts [0, 1).
        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                5,
                0..10,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );

        let result = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        );

        assert!(matches!(result, Err(Error::InvalidTimestamp(5, _))));
    }

    #[test]
    fn test_dialer_rejects_out_of_window_timestamp() -> Result<(), Error> {
        let mut rng = test_rng();
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        // The dialer only accepts [0, 1); the listener replies with t=7.
        let (d_state, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );
        let (_, msg2) = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                7,
                0..10,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        )?;

        let result = dial_end(d_state, msg2);
        assert!(matches!(result, Err(Error::InvalidTimestamp(7, _))));
        Ok(())
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Syn<crate::ed25519::Signature>>,
            CodecConformance<SynAck<crate::ed25519::Signature>>,
            CodecConformance<Ack>,
        }
    }
}