Skip to main content

commonware_cryptography/
handshake.rs

1//! This module provides an authenticated key exchange protocol, or handshake.
2//!
3//! # Design
4//!
5//! The **dialer** and the **listener** both have a public identity, known to each other in advance.
6//! The goal of the handshake is to establish a shared, encrypted, and authenticated communication
7//! channel between these two parties. No third party should be able to read messages, or send
//! messages along the channel.
9//!
10//! A three-message handshake is used to authenticate peers and establish a shared secret. The
11//! **dialer** initiates the connection, and the **listener** responds.
12//!
13//! [Syn] The dialer starts by sending a signed message with their ephemeral key.
14//!
15//! [SynAck] The listener responds by sending back their ephemeral key, along with a signature over the
16//! protocol transcript thus far. They can also derive a shared secret, which they use to generate
17//! a confirmation tag, also sent to the dialer.
18//!
//! [Ack] The dialer verifies the signed message, then derives the same secret, checks the
//! listener's confirmation tag, and uses the secret to send their own confirmation back to the
//! listener.
21//!
22//! The listener then verifies this confirmation.
23//!
//! The shared secret is then used to derive two AEAD keys: one for sending data
//! ([SendCipher]) and one for receiving data ([RecvCipher]). These use ChaCha20-Poly1305 as
//! the AEAD. Each direction has a 12-byte counter used as the nonce; every call to
//! [SendCipher::send] on one end, or [RecvCipher::recv] on the other end, increments this
//! counter. This guarantees that if
28//! a message is successfully received, then it was delivered in order. Re-ordering messages on
29//! the wire will have the effect of producing errors on the receiving end, but not of producing
30//! successful messages in a different order.
31//!
32//! # Security Features
33//!
34//! The protocol includes timestamp validation to protect against replay attacks and clock skew:
35//! - Messages with timestamps too old are rejected to prevent replay attacks
36//! - Messages with timestamps too far in the future are rejected to safeguard against clock skew
37use crate::{
38    transcript::{Summary, Transcript},
39    PublicKey, Signature, Signer, Verifier,
40};
41use commonware_codec::{Encode, FixedSize, Read, ReadExt, Write};
42use core::ops::Range;
43use rand_core::CryptoRngCore;
44
45mod error;
46pub use error::Error;
47
48mod key_exchange;
49use key_exchange::{EphemeralPublicKey, SecretKey};
50
51mod cipher;
52pub use cipher::{RecvCipher, SendCipher, TAG_SIZE};
53
54#[cfg(all(test, feature = "arbitrary"))]
55mod conformance;
56
57const NAMESPACE: &[u8] = b"_COMMONWARE_CRYPTOGRAPHY_HANDSHAKE";
58const LABEL_CIPHER_L2D: &[u8] = b"cipher_l2d";
59const LABEL_CIPHER_D2L: &[u8] = b"cipher_d2l";
60const LABEL_CONFIRMATION_L2D: &[u8] = b"confirmation_l2d";
61const LABEL_CONFIRMATION_D2L: &[u8] = b"confirmation_d2l";
62
63/// First handshake message sent by the dialer.
64/// Contains dialer's ephemeral key and timestamp signature.
65#[cfg_attr(test, derive(Debug, PartialEq))]
66pub struct Syn<S: Signature> {
67    time_ms: u64,
68    epk: EphemeralPublicKey,
69    sig: S,
70}
71
72impl<S: Signature> FixedSize for Syn<S> {
73    const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE;
74}
75
76impl<S: Signature + Write> Write for Syn<S> {
77    fn write(&self, buf: &mut impl bytes::BufMut) {
78        self.time_ms.write(buf);
79        self.epk.write(buf);
80        self.sig.write(buf);
81    }
82}
83
84impl<S: Signature + Read> Read for Syn<S> {
85    type Cfg = S::Cfg;
86
87    fn read_cfg(
88        buf: &mut impl bytes::Buf,
89        cfg: &Self::Cfg,
90    ) -> Result<Self, commonware_codec::Error> {
91        Ok(Self {
92            time_ms: ReadExt::read(buf)?,
93            epk: ReadExt::read(buf)?,
94            sig: Read::read_cfg(buf, cfg)?,
95        })
96    }
97}
98
99#[cfg(feature = "arbitrary")]
100impl<S: Signature> arbitrary::Arbitrary<'_> for Syn<S>
101where
102    S: for<'a> arbitrary::Arbitrary<'a>,
103{
104    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
105        Ok(Self {
106            time_ms: u.arbitrary()?,
107            epk: u.arbitrary()?,
108            sig: u.arbitrary()?,
109        })
110    }
111}
112
113/// Second handshake message sent by the listener.
114/// Contains listener's ephemeral key, signature, and confirmation tag.
115#[cfg_attr(test, derive(Debug, PartialEq))]
116pub struct SynAck<S: Signature> {
117    time_ms: u64,
118    epk: EphemeralPublicKey,
119    sig: S,
120    confirmation: Summary,
121}
122
123impl<S: Signature> FixedSize for SynAck<S> {
124    const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE + Summary::SIZE;
125}
126
127impl<S: Signature + Write> Write for SynAck<S> {
128    fn write(&self, buf: &mut impl bytes::BufMut) {
129        self.time_ms.write(buf);
130        self.epk.write(buf);
131        self.sig.write(buf);
132        self.confirmation.write(buf);
133    }
134}
135
136impl<S: Signature + Read> Read for SynAck<S> {
137    type Cfg = S::Cfg;
138
139    fn read_cfg(
140        buf: &mut impl bytes::Buf,
141        cfg: &Self::Cfg,
142    ) -> Result<Self, commonware_codec::Error> {
143        Ok(Self {
144            time_ms: ReadExt::read(buf)?,
145            epk: ReadExt::read(buf)?,
146            sig: Read::read_cfg(buf, cfg)?,
147            confirmation: ReadExt::read(buf)?,
148        })
149    }
150}
151
152#[cfg(feature = "arbitrary")]
153impl<S: Signature> arbitrary::Arbitrary<'_> for SynAck<S>
154where
155    S: for<'a> arbitrary::Arbitrary<'a>,
156{
157    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
158        Ok(Self {
159            time_ms: u.arbitrary()?,
160            epk: u.arbitrary()?,
161            sig: u.arbitrary()?,
162            confirmation: u.arbitrary()?,
163        })
164    }
165}
166
167/// Third handshake message sent by the dialer.
168/// Contains dialer's confirmation tag to complete the handshake.
169#[cfg_attr(test, derive(PartialEq))]
170#[cfg_attr(feature = "arbitrary", derive(Debug, arbitrary::Arbitrary))]
171pub struct Ack {
172    confirmation: Summary,
173}
174
175impl FixedSize for Ack {
176    const SIZE: usize = Summary::SIZE;
177}
178
179impl Write for Ack {
180    fn write(&self, buf: &mut impl bytes::BufMut) {
181        self.confirmation.write(buf);
182    }
183}
184
185impl Read for Ack {
186    type Cfg = ();
187
188    fn read_cfg(
189        buf: &mut impl bytes::Buf,
190        _cfg: &Self::Cfg,
191    ) -> Result<Self, commonware_codec::Error> {
192        Ok(Self {
193            confirmation: ReadExt::read(buf)?,
194        })
195    }
196}
197
198/// State maintained by the dialer during handshake.
199/// Tracks ephemeral secret, peer identity, and protocol transcript.
200pub struct DialState<P> {
201    esk: SecretKey,
202    peer_identity: P,
203    transcript: Transcript,
204    ok_timestamps: Range<u64>,
205}
206
207/// State maintained by the listener during handshake.
208/// Tracks expected confirmation and derived ciphers.
209pub struct ListenState {
210    confirmation: Summary,
211    send: SendCipher,
212    recv: RecvCipher,
213}
214
215/// Handshake context containing timing and identity information.
216/// Used by both dialer and listener to initialize handshake state.
217pub struct Context<S, P> {
218    transcript: Transcript,
219    current_time: u64,
220    ok_timestamps: Range<u64>,
221    my_identity: S,
222    peer_identity: P,
223}
224
225impl<S, P> Context<S, P> {
226    /// Creates a new handshake context.
227    pub fn new(
228        base: &Transcript,
229        current_time_ms: u64,
230        ok_timestamps: Range<u64>,
231        my_identity: S,
232        peer_identity: P,
233    ) -> Self {
234        Self {
235            transcript: base.fork(NAMESPACE),
236            current_time: current_time_ms,
237            ok_timestamps,
238            my_identity,
239            peer_identity,
240        }
241    }
242}
243
244/// Initiates a handshake as the dialer.
245/// Returns the dialer state and the first message to send.
246pub fn dial_start<S: Signer, P: PublicKey>(
247    rng: impl CryptoRngCore,
248    ctx: Context<S, P>,
249) -> (DialState<P>, Syn<<S as Signer>::Signature>) {
250    let Context {
251        current_time,
252        ok_timestamps,
253        my_identity,
254        peer_identity,
255        mut transcript,
256    } = ctx;
257    let esk = SecretKey::new(rng);
258    let epk = esk.public();
259    let sig = transcript
260        .commit(current_time.encode())
261        .commit(peer_identity.encode())
262        .commit(epk.encode())
263        .sign(&my_identity);
264    transcript.commit(my_identity.public_key().encode());
265    (
266        DialState {
267            esk,
268            peer_identity,
269            transcript,
270            ok_timestamps,
271        },
272        Syn {
273            time_ms: current_time,
274            epk,
275            sig,
276        },
277    )
278}
279
280/// Completes a handshake as the dialer.
281/// Verifies the listener's response and returns final message and ciphers.
282pub fn dial_end<P: PublicKey>(
283    state: DialState<P>,
284    msg: SynAck<<P as Verifier>::Signature>,
285) -> Result<(Ack, SendCipher, RecvCipher), Error> {
286    let DialState {
287        esk,
288        peer_identity,
289        mut transcript,
290        ok_timestamps,
291    } = state;
292    if !ok_timestamps.contains(&msg.time_ms) {
293        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
294    }
295    if !transcript
296        .commit(msg.time_ms.encode())
297        .commit(msg.epk.encode())
298        .verify(&peer_identity, &msg.sig)
299    {
300        return Err(Error::HandshakeFailed);
301    }
302    let Some(shared) = esk.exchange(&msg.epk) else {
303        return Err(Error::HandshakeFailed);
304    };
305    shared
306        .secret
307        .expose(|secret| transcript.commit(secret.as_ref()));
308    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_L2D));
309    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_D2L));
310    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
311    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
312    if msg.confirmation != confirmation_l2d {
313        return Err(Error::HandshakeFailed);
314    }
315
316    Ok((
317        Ack {
318            confirmation: confirmation_d2l,
319        },
320        send,
321        recv,
322    ))
323}
324
325/// Processes the first handshake message as the listener.
326/// Verifies the dialer's message and returns state and response.
327pub fn listen_start<S: Signer, P: PublicKey>(
328    rng: &mut impl CryptoRngCore,
329    ctx: Context<S, P>,
330    msg: Syn<<P as Verifier>::Signature>,
331) -> Result<(ListenState, SynAck<<S as Signer>::Signature>), Error> {
332    let Context {
333        current_time,
334        my_identity,
335        peer_identity,
336        ok_timestamps,
337        mut transcript,
338    } = ctx;
339    if !ok_timestamps.contains(&msg.time_ms) {
340        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
341    }
342    if !transcript
343        .commit(msg.time_ms.encode())
344        .commit(my_identity.public_key().encode())
345        .commit(msg.epk.encode())
346        .verify(&peer_identity, &msg.sig)
347    {
348        return Err(Error::HandshakeFailed);
349    }
350    let esk = SecretKey::new(rng);
351    let epk = esk.public();
352    let sig = transcript
353        .commit(peer_identity.encode())
354        .commit(current_time.encode())
355        .commit(epk.encode())
356        .sign(&my_identity);
357    let Some(shared) = esk.exchange(&msg.epk) else {
358        return Err(Error::HandshakeFailed);
359    };
360    shared
361        .secret
362        .expose(|secret| transcript.commit(secret.as_ref()));
363    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_L2D));
364    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_D2L));
365    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
366    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
367
368    Ok((
369        ListenState {
370            confirmation: confirmation_d2l,
371            send,
372            recv,
373        },
374        SynAck {
375            time_ms: current_time,
376            epk,
377            sig,
378            confirmation: confirmation_l2d,
379        },
380    ))
381}
382
383/// Completes the handshake as the listener.
384/// Verifies the dialer's confirmation and returns established ciphers.
385pub fn listen_end(state: ListenState, msg: Ack) -> Result<(SendCipher, RecvCipher), Error> {
386    if msg.confirmation != state.confirmation {
387        return Err(Error::HandshakeFailed);
388    }
389    Ok((state.send, state.recv))
390}
391
392#[cfg(test)]
393mod test {
394    use super::*;
395    use crate::{ed25519::PrivateKey, transcript::Transcript, Signer};
396    use commonware_codec::{Codec, DecodeExt};
397    use commonware_math::algebra::Random;
398    use commonware_utils::test_rng;
399
400    fn test_encode_roundtrip<T: Codec<Cfg = ()> + PartialEq>(value: &T) {
401        assert!(value == &<T as DecodeExt<_>>::decode(value.encode()).unwrap());
402    }
403
404    #[test]
405    fn test_can_setup_and_send_messages() -> Result<(), Error> {
406        let mut rng = test_rng();
407        let dialer_crypto = PrivateKey::random(&mut rng);
408        let listener_crypto = PrivateKey::random(&mut rng);
409
410        let (d_state, msg1) = dial_start(
411            &mut rng,
412            Context::new(
413                &Transcript::new(b"test_namespace"),
414                0,
415                0..1,
416                dialer_crypto.clone(),
417                listener_crypto.public_key(),
418            ),
419        );
420        test_encode_roundtrip(&msg1);
421        let (l_state, msg2) = listen_start(
422            &mut rng,
423            Context::new(
424                &Transcript::new(b"test_namespace"),
425                0,
426                0..1,
427                listener_crypto,
428                dialer_crypto.public_key(),
429            ),
430            msg1,
431        )?;
432        test_encode_roundtrip(&msg2);
433        let (msg3, mut d_send, mut d_recv) = dial_end(d_state, msg2)?;
434        test_encode_roundtrip(&msg3);
435        let (mut l_send, mut l_recv) = listen_end(l_state, msg3)?;
436
437        let m1: &'static [u8] = b"message 1";
438
439        let c1 = d_send.send(m1)?;
440        let m1_prime = l_recv.recv(&c1)?;
441        assert_eq!(m1, &m1_prime);
442
443        let m2: &'static [u8] = b"message 2";
444        let c2 = l_send.send(m2)?;
445        let m2_prime = d_recv.recv(&c2)?;
446        assert_eq!(m2, &m2_prime);
447
448        Ok(())
449    }
450
451    #[test]
452    fn test_mismatched_namespace_fails() {
453        let mut rng = test_rng();
454        let dialer_crypto = PrivateKey::random(&mut rng);
455        let listener_crypto = PrivateKey::random(&mut rng);
456
457        let (_, msg1) = dial_start(
458            &mut rng,
459            Context::new(
460                &Transcript::new(b"namespace_a"),
461                0,
462                0..1,
463                dialer_crypto.clone(),
464                listener_crypto.public_key(),
465            ),
466        );
467
468        let result = listen_start(
469            &mut rng,
470            Context::new(
471                &Transcript::new(b"namespace_b"),
472                0,
473                0..1,
474                listener_crypto,
475                dialer_crypto.public_key(),
476            ),
477            msg1,
478        );
479
480        assert!(matches!(result, Err(Error::HandshakeFailed)));
481    }
482
483    #[cfg(feature = "arbitrary")]
484    mod conformance {
485        use super::*;
486        use commonware_codec::conformance::CodecConformance;
487
488        commonware_conformance::conformance_tests! {
489            CodecConformance<Syn<crate::ed25519::Signature>>,
490            CodecConformance<SynAck<crate::ed25519::Signature>>,
491            CodecConformance<Ack>,
492        }
493    }
494}