1use crate::{
36 transcript::{Summary, Transcript},
37 PublicKey, Signature, Signer, Verifier,
38};
39use commonware_codec::{Encode, FixedSize, Read, ReadExt, Write};
40use core::ops::Range;
41use rand_core::CryptoRngCore;
42
43mod error;
44pub use error::Error;
45
46mod key_exchange;
47use key_exchange::{EphemeralPublicKey, SecretKey};
48
49mod cipher;
50pub use cipher::{RecvCipher, SendCipher, CIPHERTEXT_OVERHEAD};
51
52#[cfg(all(test, feature = "arbitrary"))]
53mod conformance;
54
/// Label used to fork the caller-supplied base transcript for this handshake (see `Context::new`).
const NAMESPACE: &[u8] = b"_COMMONWARE_CRYPTOGRAPHY_HANDSHAKE";
/// Label passed to `Transcript::noise` to derive the listener-to-dialer cipher
/// (the listener's `SendCipher`, the dialer's `RecvCipher`).
const LABEL_CIPHER_L2D: &[u8] = b"cipher_l2d";
/// Label passed to `Transcript::noise` to derive the dialer-to-listener cipher
/// (the dialer's `SendCipher`, the listener's `RecvCipher`).
const LABEL_CIPHER_D2L: &[u8] = b"cipher_d2l";
/// Fork label for the confirmation the listener sends in its `SynAck` and the dialer checks.
const LABEL_CONFIRMATION_L2D: &[u8] = b"confirmation_l2d";
/// Fork label for the confirmation the dialer sends in its `Ack` and the listener checks.
const LABEL_CONFIRMATION_D2L: &[u8] = b"confirmation_d2l";
60
/// First handshake message, sent by the dialer to the listener.
///
/// Fields are encoded back-to-back in declaration order.
#[cfg_attr(test, derive(Debug, PartialEq))]
pub struct Syn<S: Signature> {
    /// The dialer's clock reading (milliseconds) when it started the handshake.
    time_ms: u64,
    /// The dialer's ephemeral key-exchange public key.
    epk: EphemeralPublicKey,
    /// Dialer signature over the handshake transcript after committing `time_ms`,
    /// the listener's public key, and `epk` (in that order).
    sig: S,
}
69
70impl<S: Signature> FixedSize for Syn<S> {
71 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE;
72}
73
74impl<S: Signature + Write> Write for Syn<S> {
75 fn write(&self, buf: &mut impl bytes::BufMut) {
76 self.time_ms.write(buf);
77 self.epk.write(buf);
78 self.sig.write(buf);
79 }
80}
81
82impl<S: Signature + Read> Read for Syn<S> {
83 type Cfg = S::Cfg;
84
85 fn read_cfg(
86 buf: &mut impl bytes::Buf,
87 cfg: &Self::Cfg,
88 ) -> Result<Self, commonware_codec::Error> {
89 Ok(Self {
90 time_ms: ReadExt::read(buf)?,
91 epk: ReadExt::read(buf)?,
92 sig: Read::read_cfg(buf, cfg)?,
93 })
94 }
95}
96
#[cfg(feature = "arbitrary")]
impl<S: Signature> arbitrary::Arbitrary<'_> for Syn<S>
where
    S: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws each field from `u` in declaration order (used by fuzzing/conformance tests).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let time_ms: u64 = u.arbitrary()?;
        let epk: EphemeralPublicKey = u.arbitrary()?;
        let sig: S = u.arbitrary()?;
        Ok(Self { time_ms, epk, sig })
    }
}
110
/// Second handshake message, sent by the listener in reply to a [`Syn`].
///
/// Fields are encoded back-to-back in declaration order.
#[cfg_attr(test, derive(Debug, PartialEq))]
pub struct SynAck<S: Signature> {
    /// The listener's clock reading (milliseconds) when it answered.
    time_ms: u64,
    /// The listener's ephemeral key-exchange public key.
    epk: EphemeralPublicKey,
    /// Listener signature over the transcript extended (after the dialer's `Syn` was
    /// verified) with the dialer's public key, `time_ms`, and `epk`.
    sig: S,
    /// Listener-to-dialer confirmation derived after the shared secret was committed;
    /// the dialer must derive the same value or the handshake fails.
    confirmation: Summary,
}
120
121impl<S: Signature> FixedSize for SynAck<S> {
122 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE + Summary::SIZE;
123}
124
125impl<S: Signature + Write> Write for SynAck<S> {
126 fn write(&self, buf: &mut impl bytes::BufMut) {
127 self.time_ms.write(buf);
128 self.epk.write(buf);
129 self.sig.write(buf);
130 self.confirmation.write(buf);
131 }
132}
133
134impl<S: Signature + Read> Read for SynAck<S> {
135 type Cfg = S::Cfg;
136
137 fn read_cfg(
138 buf: &mut impl bytes::Buf,
139 cfg: &Self::Cfg,
140 ) -> Result<Self, commonware_codec::Error> {
141 Ok(Self {
142 time_ms: ReadExt::read(buf)?,
143 epk: ReadExt::read(buf)?,
144 sig: Read::read_cfg(buf, cfg)?,
145 confirmation: ReadExt::read(buf)?,
146 })
147 }
148}
149
#[cfg(feature = "arbitrary")]
impl<S: Signature> arbitrary::Arbitrary<'_> for SynAck<S>
where
    S: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Draws each field from `u` in declaration order (used by fuzzing/conformance tests).
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let time_ms: u64 = u.arbitrary()?;
        let epk: EphemeralPublicKey = u.arbitrary()?;
        let sig: S = u.arbitrary()?;
        let confirmation: Summary = u.arbitrary()?;
        Ok(Self {
            time_ms,
            epk,
            sig,
            confirmation,
        })
    }
}
164
/// Final handshake message, sent by the dialer once it has accepted the [`SynAck`].
#[cfg_attr(test, derive(PartialEq))]
#[cfg_attr(feature = "arbitrary", derive(Debug, arbitrary::Arbitrary))]
pub struct Ack {
    /// Dialer-to-listener confirmation; the listener compares it against the value it
    /// derived in `listen_start`.
    confirmation: Summary,
}
172
impl FixedSize for Ack {
    // An `Ack` carries nothing but its confirmation summary.
    const SIZE: usize = Summary::SIZE;
}
176
177impl Write for Ack {
178 fn write(&self, buf: &mut impl bytes::BufMut) {
179 self.confirmation.write(buf);
180 }
181}
182
183impl Read for Ack {
184 type Cfg = ();
185
186 fn read_cfg(
187 buf: &mut impl bytes::Buf,
188 _cfg: &Self::Cfg,
189 ) -> Result<Self, commonware_codec::Error> {
190 Ok(Self {
191 confirmation: ReadExt::read(buf)?,
192 })
193 }
194}
195
/// Dialer-side state carried from [`dial_start`] to [`dial_end`].
pub struct DialState<P> {
    /// The dialer's ephemeral secret key, used in `dial_end` for the key exchange.
    esk: SecretKey,
    /// Public key the listener's `SynAck` signature must verify under.
    peer_identity: P,
    /// Transcript after the dialer's `Syn` contents and its own public key were committed.
    transcript: Transcript,
    /// Accepted (half-open) range for the listener's `time_ms`.
    ok_timestamps: Range<u64>,
}
204
/// Listener-side state carried from [`listen_start`] to [`listen_end`].
///
/// The session ciphers are already derived, but they are only handed out once the
/// dialer's [`Ack`] carries the expected confirmation.
pub struct ListenState {
    /// Dialer-to-listener confirmation that the dialer's `Ack` must match.
    confirmation: Summary,
    /// Cipher for listener-to-dialer traffic.
    send: SendCipher,
    /// Cipher for dialer-to-listener traffic.
    recv: RecvCipher,
}
212
/// Inputs shared by both sides when starting a handshake (see [`Context::new`]).
pub struct Context<S, P> {
    /// Handshake transcript, forked from the caller's base transcript.
    transcript: Transcript,
    /// Local clock reading in milliseconds, stamped into this side's outgoing message.
    current_time: u64,
    /// Accepted (half-open) range for the peer's message timestamp.
    ok_timestamps: Range<u64>,
    /// This side's long-term identity; signs this side's handshake message.
    my_identity: S,
    /// The peer's expected public key; the peer's signature must verify under it.
    peer_identity: P,
}
222
223impl<S, P> Context<S, P> {
224 pub fn new(
226 base: &Transcript,
227 current_time_ms: u64,
228 ok_timestamps: Range<u64>,
229 my_identity: S,
230 peer_identity: P,
231 ) -> Self {
232 Self {
233 transcript: base.fork(NAMESPACE),
234 current_time: current_time_ms,
235 ok_timestamps,
236 my_identity,
237 peer_identity,
238 }
239 }
240}
241
/// Begins a handshake as the dialer.
///
/// Generates a fresh ephemeral key pair from `rng` and signs the transcript, after
/// committing the current time, the listener's public key, and the ephemeral public key,
/// with the dialer's identity. Returns the [`DialState`] to pass to [`dial_end`] and the
/// [`Syn`] to send to the listener.
pub fn dial_start<S: Signer, P: PublicKey>(
    rng: impl CryptoRngCore,
    ctx: Context<S, P>,
) -> (DialState<P>, Syn<<S as Signer>::Signature>) {
    let Context {
        current_time,
        ok_timestamps,
        my_identity,
        peer_identity,
        mut transcript,
    } = ctx;
    let esk = SecretKey::new(rng);
    let epk = esk.public();
    // Commit order (time, listener key, ephemeral key) must match the verification
    // performed by `listen_start`.
    let sig = transcript
        .commit(current_time.encode())
        .commit(peer_identity.encode())
        .commit(epk.encode())
        .sign(&my_identity);
    // Bind the dialer's own public key after signing; `listen_start` commits the same key
    // (its `peer_identity`) at this point before producing the listener's signature.
    transcript.commit(my_identity.public_key().encode());
    (
        DialState {
            esk,
            peer_identity,
            transcript,
            ok_timestamps,
        },
        Syn {
            time_ms: current_time,
            epk,
            sig,
        },
    )
}
277
/// Completes the handshake as the dialer, given the listener's [`SynAck`].
///
/// Checks the listener's timestamp, verifies its signature, performs the ephemeral key
/// exchange, derives the session ciphers, and checks the listener's confirmation. On
/// success returns the [`Ack`] to send back together with the dialer's send and receive
/// ciphers.
///
/// # Errors
///
/// - [`Error::InvalidTimestamp`] if `msg.time_ms` is outside the accepted range.
/// - [`Error::HandshakeFailed`] if the signature does not verify under the expected
///   listener identity, the key exchange produces no secret, or the listener's
///   confirmation does not match the locally derived one.
pub fn dial_end<P: PublicKey>(
    state: DialState<P>,
    msg: SynAck<<P as Verifier>::Signature>,
) -> Result<(Ack, SendCipher, RecvCipher), Error> {
    let DialState {
        esk,
        peer_identity,
        mut transcript,
        ok_timestamps,
    } = state;
    if !ok_timestamps.contains(&msg.time_ms) {
        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
    }
    // The transcript already holds the dialer's commitments; extend it the same way the
    // listener did before signing (its time, then its ephemeral key).
    if !transcript
        .commit(msg.time_ms.encode())
        .commit(msg.epk.encode())
        .verify(&peer_identity, &msg.sig)
    {
        return Err(Error::HandshakeFailed);
    }
    let Some(secret) = esk.exchange(&msg.epk) else {
        return Err(Error::HandshakeFailed);
    };
    transcript.commit(secret.as_ref());
    // The dialer receives on the listener-to-dialer cipher and sends on dialer-to-listener.
    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_L2D));
    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_D2L));
    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
    // A matching confirmation shows the listener derived the same transcript state.
    if msg.confirmation != confirmation_l2d {
        return Err(Error::HandshakeFailed);
    }

    Ok((
        Ack {
            confirmation: confirmation_d2l,
        },
        send,
        recv,
    ))
}
320
/// Handles the dialer's [`Syn`] as the listener.
///
/// Checks the dialer's timestamp and signature, generates the listener's ephemeral key
/// pair from `rng`, signs the extended transcript, performs the key exchange, and derives
/// the session ciphers plus both confirmations. Returns the [`ListenState`] to pass to
/// [`listen_end`] and the [`SynAck`] to send back to the dialer.
///
/// # Errors
///
/// - [`Error::InvalidTimestamp`] if `msg.time_ms` is outside the accepted range.
/// - [`Error::HandshakeFailed`] if the dialer's signature does not verify under the
///   expected dialer identity, or the key exchange produces no secret.
pub fn listen_start<S: Signer, P: PublicKey>(
    rng: &mut impl CryptoRngCore,
    ctx: Context<S, P>,
    msg: Syn<<P as Verifier>::Signature>,
) -> Result<(ListenState, SynAck<<S as Signer>::Signature>), Error> {
    let Context {
        current_time,
        my_identity,
        peer_identity,
        ok_timestamps,
        mut transcript,
    } = ctx;
    if !ok_timestamps.contains(&msg.time_ms) {
        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
    }
    // Reproduce the dialer's commit order from `dial_start`: its time, the listener's
    // (our) public key, then its ephemeral key.
    if !transcript
        .commit(msg.time_ms.encode())
        .commit(my_identity.public_key().encode())
        .commit(msg.epk.encode())
        .verify(&peer_identity, &msg.sig)
    {
        return Err(Error::HandshakeFailed);
    }
    let esk = SecretKey::new(rng);
    let epk = esk.public();
    // The dialer's public key goes first here, matching the commit `dial_start` makes
    // after signing; `dial_end` then verifies over (our time, our ephemeral key).
    let sig = transcript
        .commit(peer_identity.encode())
        .commit(current_time.encode())
        .commit(epk.encode())
        .sign(&my_identity);
    let Some(secret) = esk.exchange(&msg.epk) else {
        return Err(Error::HandshakeFailed);
    };
    transcript.commit(secret.as_ref());
    // The listener sends on the listener-to-dialer cipher and receives on dialer-to-listener.
    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_L2D));
    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_D2L));
    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();

    Ok((
        ListenState {
            confirmation: confirmation_d2l,
            send,
            recv,
        },
        SynAck {
            time_ms: current_time,
            epk,
            sig,
            confirmation: confirmation_l2d,
        },
    ))
}
376
377pub fn listen_end(state: ListenState, msg: Ack) -> Result<(SendCipher, RecvCipher), Error> {
380 if msg.confirmation != state.confirmation {
381 return Err(Error::HandshakeFailed);
382 }
383 Ok((state.send, state.recv))
384}
385
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ed25519::PrivateKey, transcript::Transcript, Signer};
    use commonware_codec::{Codec, DecodeExt};
    use commonware_math::algebra::Random;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Asserts that `value` survives an encode/decode roundtrip unchanged.
    fn test_encode_roundtrip<T: Codec<Cfg = ()> + PartialEq>(value: &T) {
        assert!(value == &<T as DecodeExt<_>>::decode(value.encode()).unwrap());
    }

    /// Runs a handshake between `dialer` and `listener` up to (but not including)
    /// `listen_end`, returning the listener's pending state and the dialer's `Ack`.
    fn run_until_ack(
        rng: &mut ChaCha8Rng,
        dialer: &PrivateKey,
        listener: &PrivateKey,
    ) -> Result<(ListenState, Ack), Error> {
        let (d_state, msg1) = dial_start(
            &mut *rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                dialer.clone(),
                listener.public_key(),
            ),
        );
        let (l_state, msg2) = listen_start(
            rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener.clone(),
                dialer.public_key(),
            ),
            msg1,
        )?;
        let (msg3, _, _) = dial_end(d_state, msg2)?;
        Ok((l_state, msg3))
    }

    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (d_state, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );
        test_encode_roundtrip(&msg1);
        let (l_state, msg2) = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        )?;
        test_encode_roundtrip(&msg2);
        let (msg3, mut d_send, mut d_recv) = dial_end(d_state, msg2)?;
        test_encode_roundtrip(&msg3);
        let (mut l_send, mut l_recv) = listen_end(l_state, msg3)?;

        let m1: &'static [u8] = b"message 1";

        let c1 = d_send.send(m1)?;
        let m1_prime = l_recv.recv(&c1)?;
        assert_eq!(m1, &m1_prime);

        let m2: &'static [u8] = b"message 2";
        let c2 = l_send.send(m2)?;
        let m2_prime = d_recv.recv(&c2)?;
        assert_eq!(m2, &m2_prime);

        Ok(())
    }

    #[test]
    fn test_mismatched_namespace_fails() {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"namespace_a"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );

        let result = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"namespace_b"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        );

        assert!(matches!(result, Err(Error::HandshakeFailed)));
    }

    #[test]
    fn test_listener_rejects_out_of_window_timestamp() {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        // The dialer stamps its Syn with t=5, but the listener only accepts [0, 1).
        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                5,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );

        let result = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        );

        assert!(matches!(result, Err(Error::InvalidTimestamp(5, _))));
    }

    #[test]
    fn test_dialer_rejects_out_of_window_timestamp() -> Result<(), Error> {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        let (d_state, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                dialer_crypto.clone(),
                listener_crypto.public_key(),
            ),
        );
        // The listener accepts the dialer's t=0 but stamps its own reply with t=7,
        // which is outside the dialer's accepted range [0, 1).
        let (_, msg2) = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                7,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        )?;

        let result = dial_end(d_state, msg2);
        assert!(matches!(result, Err(Error::InvalidTimestamp(7, _))));
        Ok(())
    }

    #[test]
    fn test_unexpected_dialer_identity_fails() {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);
        let impostor_crypto = PrivateKey::random(&mut rng);

        // An impostor signs the Syn, but the listener expects `dialer_crypto`.
        let (_, msg1) = dial_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                impostor_crypto,
                listener_crypto.public_key(),
            ),
        );

        let result = listen_start(
            &mut rng,
            Context::new(
                &Transcript::new(b"test_namespace"),
                0,
                0..1,
                listener_crypto,
                dialer_crypto.public_key(),
            ),
            msg1,
        );

        assert!(matches!(result, Err(Error::HandshakeFailed)));
    }

    #[test]
    fn test_ack_from_other_session_fails() -> Result<(), Error> {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_crypto = PrivateKey::random(&mut rng);
        let listener_crypto = PrivateKey::random(&mut rng);

        // Two independent sessions use different ephemeral keys, so the `Ack` of one
        // must not complete the other.
        let (l_state_a, _) = run_until_ack(&mut rng, &dialer_crypto, &listener_crypto)?;
        let (_, ack_b) = run_until_ack(&mut rng, &dialer_crypto, &listener_crypto)?;

        assert!(matches!(
            listen_end(l_state_a, ack_b),
            Err(Error::HandshakeFailed)
        ));
        Ok(())
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Syn<crate::ed25519::Signature>>,
            CodecConformance<SynAck<crate::ed25519::Signature>>,
            CodecConformance<Ack>,
        }
    }
}