1use crate::{
36 transcript::{Summary, Transcript},
37 PublicKey, Signature, Signer, Verifier,
38};
39use commonware_codec::{Encode, FixedSize, Read, ReadExt, Write};
40use core::ops::Range;
41use rand_core::CryptoRngCore;
42
43mod error;
44pub use error::Error;
45
46mod key_exchange;
47use key_exchange::{EphemeralPublicKey, SecretKey};
48
49mod cipher;
50pub use cipher::{RecvCipher, SendCipher, CIPHERTEXT_OVERHEAD};
51
/// Namespace passed to `Transcript::new` when either side (`dial_start` or
/// `listen_start`) begins a handshake transcript.
const NAMESPACE: &[u8] = b"commonware/handshake";
/// Label for the listener-to-dialer cipher: the listener's send side and the dialer's
/// receive side are both derived with this label.
const LABEL_CIPHER_L2D: &[u8] = b"cipher_l2d";
/// Label for the dialer-to-listener cipher: the dialer's send side and the listener's
/// receive side are both derived with this label.
const LABEL_CIPHER_D2L: &[u8] = b"cipher_d2l";
/// Label for the confirmation the listener places in [`SynAck`] and the dialer checks.
const LABEL_CONFIRMATION_L2D: &[u8] = b"confirmation_l2d";
/// Label for the confirmation the dialer places in [`Ack`] and the listener checks.
const LABEL_CONFIRMATION_D2L: &[u8] = b"confirmation_d2l";
57
/// First handshake message, sent by the dialer to the listener.
///
/// Produced by `dial_start` and consumed by `listen_start`.
#[cfg_attr(test, derive(PartialEq))]
pub struct Syn<S: Signature> {
    /// Dialer's current time (milliseconds, per the field name). The listener rejects the
    /// message if this is outside its `ok_timestamps` range.
    time_ms: u64,
    /// Dialer's ephemeral public key for the key exchange.
    epk: EphemeralPublicKey,
    /// Dialer's signature over the transcript: this timestamp, the listener's public key,
    /// and `epk`.
    sig: S,
}
66
67impl<S: Signature> FixedSize for Syn<S> {
68 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE;
69}
70
71impl<S: Signature + Write> Write for Syn<S> {
72 fn write(&self, buf: &mut impl bytes::BufMut) {
73 self.time_ms.write(buf);
74 self.epk.write(buf);
75 self.sig.write(buf);
76 }
77}
78
79impl<S: Signature + Read> Read for Syn<S> {
80 type Cfg = S::Cfg;
81
82 fn read_cfg(
83 buf: &mut impl bytes::Buf,
84 cfg: &Self::Cfg,
85 ) -> Result<Self, commonware_codec::Error> {
86 Ok(Self {
87 time_ms: ReadExt::read(buf)?,
88 epk: ReadExt::read(buf)?,
89 sig: Read::read_cfg(buf, cfg)?,
90 })
91 }
92}
93
/// Second handshake message, sent by the listener back to the dialer.
///
/// Produced by `listen_start` and consumed by `dial_end`.
#[cfg_attr(test, derive(PartialEq))]
pub struct SynAck<S: Signature> {
    /// Listener's current time (milliseconds, per the field name). The dialer rejects the
    /// message if this is outside its `ok_timestamps` range.
    time_ms: u64,
    /// Listener's ephemeral public key for the key exchange.
    epk: EphemeralPublicKey,
    /// Listener's signature over the transcript. At signing time the transcript also covers
    /// the dialer's `Syn` fields and public key, followed by this timestamp and `epk`.
    sig: S,
    /// Confirmation derived from the post-exchange transcript under
    /// `LABEL_CONFIRMATION_L2D`; the dialer recomputes and compares it.
    confirmation: Summary,
}
103
104impl<S: Signature> FixedSize for SynAck<S> {
105 const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE + Summary::SIZE;
106}
107
108impl<S: Signature + Write> Write for SynAck<S> {
109 fn write(&self, buf: &mut impl bytes::BufMut) {
110 self.time_ms.write(buf);
111 self.epk.write(buf);
112 self.sig.write(buf);
113 self.confirmation.write(buf);
114 }
115}
116
117impl<S: Signature + Read> Read for SynAck<S> {
118 type Cfg = S::Cfg;
119
120 fn read_cfg(
121 buf: &mut impl bytes::Buf,
122 cfg: &Self::Cfg,
123 ) -> Result<Self, commonware_codec::Error> {
124 Ok(Self {
125 time_ms: ReadExt::read(buf)?,
126 epk: ReadExt::read(buf)?,
127 sig: Read::read_cfg(buf, cfg)?,
128 confirmation: ReadExt::read(buf)?,
129 })
130 }
131}
132
/// Final handshake message, sent by the dialer to the listener.
///
/// Produced by `dial_end` and consumed by `listen_end`.
#[cfg_attr(test, derive(PartialEq))]
pub struct Ack {
    /// Confirmation derived from the post-exchange transcript under
    /// `LABEL_CONFIRMATION_D2L`; the listener compares it with the value it derived itself.
    confirmation: Summary,
}
139
/// An [`Ack`] encodes to exactly the size of its single [`Summary`] field.
impl FixedSize for Ack {
    const SIZE: usize = Summary::SIZE;
}
143
144impl Write for Ack {
145 fn write(&self, buf: &mut impl bytes::BufMut) {
146 self.confirmation.write(buf);
147 }
148}
149
150impl Read for Ack {
151 type Cfg = ();
152
153 fn read_cfg(
154 buf: &mut impl bytes::Buf,
155 _cfg: &Self::Cfg,
156 ) -> Result<Self, commonware_codec::Error> {
157 Ok(Self {
158 confirmation: ReadExt::read(buf)?,
159 })
160 }
161}
162
/// Dialer-side state carried from `dial_start` to `dial_end`.
pub struct DialState<P> {
    /// Dialer's ephemeral secret key; `dial_end` uses it for the key exchange.
    esk: SecretKey,
    /// Listener's expected public key; `dial_end` verifies the `SynAck` signature against it.
    peer_identity: P,
    /// Transcript as left by `dial_start`: the namespace, the dialer's timestamp, the
    /// listener's public key, the dialer's ephemeral public key, then the dialer's own
    /// public key.
    transcript: Transcript,
    /// Range the listener's `SynAck` timestamp must fall in (`Range::contains`, so the
    /// upper bound is exclusive).
    ok_timestamps: Range<u64>,
}
171
/// Listener-side state carried from `listen_start` to `listen_end`.
pub struct ListenState {
    /// Dialer-to-listener confirmation the incoming `Ack` must equal.
    confirmation: Summary,
    /// Listener's sending cipher, handed out once the `Ack` checks out.
    send: SendCipher,
    /// Listener's receiving cipher, handed out once the `Ack` checks out.
    recv: RecvCipher,
}
179
/// Inputs one side of the handshake needs: local time, timestamp policy and identities.
pub struct Context<S, P> {
    /// This side's current time; it is sent as the outgoing message's `time_ms`
    /// (milliseconds, per `Context::new`'s parameter name).
    current_time: u64,
    /// Range the peer's message timestamp must fall in (`Range::contains`, so the upper
    /// bound is exclusive).
    ok_timestamps: Range<u64>,
    /// This side's signer, used to sign its handshake message.
    my_identity: S,
    /// Public key the peer is expected to sign with.
    peer_identity: P,
}
188
189impl<S, P> Context<S, P> {
190 pub fn new(
192 current_time_ms: u64,
193 ok_timestamps: Range<u64>,
194 my_identity: S,
195 peer_identity: P,
196 ) -> Self {
197 Self {
198 current_time: current_time_ms,
199 ok_timestamps,
200 my_identity,
201 peer_identity,
202 }
203 }
204}
205
206pub fn dial_start<S: Signer, P: PublicKey>(
209 rng: impl CryptoRngCore,
210 ctx: Context<S, P>,
211) -> (DialState<P>, Syn<<S as Signer>::Signature>) {
212 let Context {
213 current_time,
214 ok_timestamps,
215 my_identity,
216 peer_identity,
217 } = ctx;
218 let esk = SecretKey::new(rng);
219 let epk = esk.public();
220 let mut transcript = Transcript::new(NAMESPACE);
221 let sig = transcript
222 .commit(current_time.encode())
223 .commit(peer_identity.encode())
224 .commit(epk.encode())
225 .sign(&my_identity);
226 transcript.commit(my_identity.public_key().encode());
227 (
228 DialState {
229 esk,
230 peer_identity,
231 transcript,
232 ok_timestamps,
233 },
234 Syn {
235 time_ms: current_time,
236 epk,
237 sig,
238 },
239 )
240}
241
242pub fn dial_end<P: PublicKey>(
245 state: DialState<P>,
246 msg: SynAck<<P as Verifier>::Signature>,
247) -> Result<(Ack, SendCipher, RecvCipher), Error> {
248 let DialState {
249 esk,
250 peer_identity,
251 mut transcript,
252 ok_timestamps,
253 } = state;
254 if !ok_timestamps.contains(&msg.time_ms) {
255 return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
256 }
257 if !transcript
258 .commit(msg.time_ms.encode())
259 .commit(msg.epk.encode())
260 .verify(&peer_identity, &msg.sig)
261 {
262 return Err(Error::HandshakeFailed);
263 }
264 let Some(secret) = esk.exchange(&msg.epk) else {
265 return Err(Error::HandshakeFailed);
266 };
267 transcript.commit(secret.as_ref());
268 let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_L2D));
269 let send = SendCipher::new(transcript.noise(LABEL_CIPHER_D2L));
270 let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
271 let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
272 if msg.confirmation != confirmation_l2d {
273 return Err(Error::HandshakeFailed);
274 }
275
276 Ok((
277 Ack {
278 confirmation: confirmation_d2l,
279 },
280 send,
281 recv,
282 ))
283}
284
285pub fn listen_start<S: Signer, P: PublicKey>(
288 rng: &mut impl CryptoRngCore,
289 ctx: Context<S, P>,
290 msg: Syn<<P as Verifier>::Signature>,
291) -> Result<(ListenState, SynAck<<S as Signer>::Signature>), Error> {
292 let Context {
293 current_time,
294 my_identity,
295 peer_identity,
296 ok_timestamps,
297 } = ctx;
298 if !ok_timestamps.contains(&msg.time_ms) {
299 return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
300 }
301 let mut transcript = Transcript::new(NAMESPACE);
302 if !transcript
303 .commit(msg.time_ms.encode())
304 .commit(my_identity.public_key().encode())
305 .commit(msg.epk.encode())
306 .verify(&peer_identity, &msg.sig)
307 {
308 return Err(Error::HandshakeFailed);
309 }
310 let esk = SecretKey::new(rng);
311 let epk = esk.public();
312 let sig = transcript
313 .commit(peer_identity.encode())
314 .commit(current_time.encode())
315 .commit(epk.encode())
316 .sign(&my_identity);
317 let Some(secret) = esk.exchange(&msg.epk) else {
318 return Err(Error::HandshakeFailed);
319 };
320 transcript.commit(secret.as_ref());
321 let send = SendCipher::new(transcript.noise(LABEL_CIPHER_L2D));
322 let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_D2L));
323 let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
324 let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
325
326 Ok((
327 ListenState {
328 confirmation: confirmation_d2l,
329 send,
330 recv,
331 },
332 SynAck {
333 time_ms: current_time,
334 epk,
335 sig,
336 confirmation: confirmation_l2d,
337 },
338 ))
339}
340
341pub fn listen_end(state: ListenState, msg: Ack) -> Result<(SendCipher, RecvCipher), Error> {
344 if msg.confirmation != state.confirmation {
345 return Err(Error::HandshakeFailed);
346 }
347 Ok((state.send, state.recv))
348}
349
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ed25519::PrivateKey, PrivateKeyExt as _, Signer};
    use commonware_codec::{Codec, DecodeExt};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Asserts that `value` decodes back to itself after being encoded.
    fn test_encode_roundtrip<T: Codec<Cfg = ()> + PartialEq>(value: &T) {
        assert!(value == &<T as DecodeExt<_>>::decode(value.encode()).unwrap());
    }

    /// Happy path: a full handshake succeeds and both directions can exchange messages.
    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_crypto = PrivateKey::from_rng(&mut rng);
        let listener_crypto = PrivateKey::from_rng(&mut rng);

        let (d_state, msg1) = dial_start(
            &mut rng,
            Context {
                current_time: 0,
                ok_timestamps: 0..1,
                my_identity: dialer_crypto.clone(),
                peer_identity: listener_crypto.public_key(),
            },
        );
        test_encode_roundtrip(&msg1);
        let (l_state, msg2) = listen_start(
            &mut rng,
            Context {
                current_time: 0,
                ok_timestamps: 0..1,
                my_identity: listener_crypto,
                peer_identity: dialer_crypto.public_key(),
            },
            msg1,
        )?;
        test_encode_roundtrip(&msg2);
        let (msg3, mut d_send, mut d_recv) = dial_end(d_state, msg2)?;
        test_encode_roundtrip(&msg3);
        let (mut l_send, mut l_recv) = listen_end(l_state, msg3)?;

        let m1: &'static [u8] = b"message 1";

        let c1 = d_send.send(m1)?;
        let m1_prime = l_recv.recv(&c1)?;
        assert_eq!(m1, &m1_prime);

        let m2: &'static [u8] = b"message 2";
        let c2 = l_send.send(m2)?;
        let m2_prime = d_recv.recv(&c2)?;
        assert_eq!(m2, &m2_prime);

        Ok(())
    }

    /// The listener must refuse a `Syn` whose timestamp is outside its accepted range,
    /// including exactly at the (exclusive) upper bound.
    #[test]
    fn test_listener_rejects_out_of_range_timestamp() {
        let mut rng = ChaCha8Rng::seed_from_u64(1);
        let dialer_crypto = PrivateKey::from_rng(&mut rng);
        let listener_crypto = PrivateKey::from_rng(&mut rng);

        let (_, msg1) = dial_start(
            &mut rng,
            Context {
                current_time: 5,
                ok_timestamps: 0..10,
                my_identity: dialer_crypto.clone(),
                peer_identity: listener_crypto.public_key(),
            },
        );
        let result = listen_start(
            &mut rng,
            Context {
                current_time: 0,
                // Half-open range: 5 itself is not accepted.
                ok_timestamps: 0..5,
                my_identity: listener_crypto,
                peer_identity: dialer_crypto.public_key(),
            },
            msg1,
        );
        assert!(matches!(result, Err(Error::InvalidTimestamp(5, _))));
    }

    /// The listener must refuse a `Syn` signed by a key other than the expected dialer's.
    #[test]
    fn test_listener_rejects_unexpected_dialer() {
        let mut rng = ChaCha8Rng::seed_from_u64(2);
        let dialer_crypto = PrivateKey::from_rng(&mut rng);
        let listener_crypto = PrivateKey::from_rng(&mut rng);
        let impostor_crypto = PrivateKey::from_rng(&mut rng);

        let (_, msg1) = dial_start(
            &mut rng,
            Context {
                current_time: 0,
                ok_timestamps: 0..1,
                my_identity: impostor_crypto,
                peer_identity: listener_crypto.public_key(),
            },
        );
        let result = listen_start(
            &mut rng,
            Context {
                current_time: 0,
                ok_timestamps: 0..1,
                my_identity: listener_crypto,
                peer_identity: dialer_crypto.public_key(),
            },
            msg1,
        );
        assert!(matches!(result, Err(Error::HandshakeFailed)));
    }

    /// The dialer must refuse a `SynAck` whose timestamp is outside its accepted range.
    #[test]
    fn test_dialer_rejects_out_of_range_timestamp() -> Result<(), Error> {
        let mut rng = ChaCha8Rng::seed_from_u64(3);
        let dialer_crypto = PrivateKey::from_rng(&mut rng);
        let listener_crypto = PrivateKey::from_rng(&mut rng);

        let (d_state, msg1) = dial_start(
            &mut rng,
            Context {
                current_time: 0,
                ok_timestamps: 0..1,
                my_identity: dialer_crypto.clone(),
                peer_identity: listener_crypto.public_key(),
            },
        );
        let (_, msg2) = listen_start(
            &mut rng,
            Context {
                current_time: 5,
                ok_timestamps: 0..1,
                my_identity: listener_crypto,
                peer_identity: dialer_crypto.public_key(),
            },
            msg1,
        )?;
        assert!(matches!(
            dial_end(d_state, msg2),
            Err(Error::InvalidTimestamp(5, _))
        ));
        Ok(())
    }
}