use crate::{
transcript::{Summary, Transcript},
PublicKey, Signature, Signer, Verifier,
};
use commonware_codec::{Encode, FixedSize, Read, ReadExt, Write};
use core::ops::Range;
use rand_core::CryptoRngCore;
mod error;
pub use error::Error;
mod key_exchange;
use key_exchange::{EphemeralPublicKey, SecretKey};
mod cipher;
pub use cipher::{RecvCipher, SendCipher, CIPHERTEXT_OVERHEAD};
/// Namespace passed to [`Transcript::new`] when either side starts a handshake transcript.
const NAMESPACE: &[u8] = b"commonware/handshake";
/// `noise` label for the listener-to-dialer cipher (listener sends, dialer receives).
const LABEL_CIPHER_L2D: &[u8] = b"cipher_l2d";
/// `noise` label for the dialer-to-listener cipher (dialer sends, listener receives).
const LABEL_CIPHER_D2L: &[u8] = b"cipher_d2l";
/// `fork` label for the confirmation tag the listener sends in its `SynAck`.
const LABEL_CONFIRMATION_L2D: &[u8] = b"confirmation_l2d";
/// `fork` label for the confirmation tag the dialer sends in its `Ack`.
const LABEL_CONFIRMATION_D2L: &[u8] = b"confirmation_d2l";
#[cfg_attr(test, derive(PartialEq))]
pub struct Syn<S: Signature> {
time_ms: u64,
epk: EphemeralPublicKey,
sig: S,
}
impl<S: Signature> FixedSize for Syn<S> {
const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE;
}
impl<S: Signature + Write> Write for Syn<S> {
fn write(&self, buf: &mut impl bytes::BufMut) {
self.time_ms.write(buf);
self.epk.write(buf);
self.sig.write(buf);
}
}
impl<S: Signature + Read> Read for Syn<S> {
type Cfg = S::Cfg;
fn read_cfg(
buf: &mut impl bytes::Buf,
cfg: &Self::Cfg,
) -> Result<Self, commonware_codec::Error> {
Ok(Self {
time_ms: ReadExt::read(buf)?,
epk: ReadExt::read(buf)?,
sig: Read::read_cfg(buf, cfg)?,
})
}
}
#[cfg_attr(test, derive(PartialEq))]
pub struct SynAck<S: Signature> {
time_ms: u64,
epk: EphemeralPublicKey,
sig: S,
confirmation: Summary,
}
impl<S: Signature> FixedSize for SynAck<S> {
const SIZE: usize = u64::SIZE + EphemeralPublicKey::SIZE + S::SIZE + Summary::SIZE;
}
impl<S: Signature + Write> Write for SynAck<S> {
fn write(&self, buf: &mut impl bytes::BufMut) {
self.time_ms.write(buf);
self.epk.write(buf);
self.sig.write(buf);
self.confirmation.write(buf);
}
}
impl<S: Signature + Read> Read for SynAck<S> {
type Cfg = S::Cfg;
fn read_cfg(
buf: &mut impl bytes::Buf,
cfg: &Self::Cfg,
) -> Result<Self, commonware_codec::Error> {
Ok(Self {
time_ms: ReadExt::read(buf)?,
epk: ReadExt::read(buf)?,
sig: Read::read_cfg(buf, cfg)?,
confirmation: ReadExt::read(buf)?,
})
}
}
#[cfg_attr(test, derive(PartialEq))]
pub struct Ack {
confirmation: Summary,
}
impl FixedSize for Ack {
const SIZE: usize = Summary::SIZE;
}
impl Write for Ack {
fn write(&self, buf: &mut impl bytes::BufMut) {
self.confirmation.write(buf);
}
}
impl Read for Ack {
type Cfg = ();
fn read_cfg(
buf: &mut impl bytes::Buf,
_cfg: &Self::Cfg,
) -> Result<Self, commonware_codec::Error> {
Ok(Self {
confirmation: ReadExt::read(buf)?,
})
}
}
/// Dialer-side state carried from [`dial_start`] to [`dial_end`].
pub struct DialState<P> {
    /// Ephemeral secret key generated in `dial_start`; used for the key exchange in `dial_end`.
    esk: SecretKey,
    /// Long-term public key the listener's `SynAck` signature is verified against.
    peer_identity: P,
    /// Handshake transcript as it stands at the end of `dial_start`.
    transcript: Transcript,
    /// Accepted (half-open) range for the listener's `SynAck` timestamp.
    ok_timestamps: Range<u64>,
}
/// Listener-side state carried from [`listen_start`] to [`listen_end`].
pub struct ListenState {
    /// Dialer-to-listener confirmation tag that the dialer's `Ack` must match.
    confirmation: Summary,
    /// Listener's sending cipher (listener-to-dialer direction).
    send: SendCipher,
    /// Listener's receiving cipher (dialer-to-listener direction).
    recv: RecvCipher,
}
/// Inputs one side needs to run its half of the handshake.
pub struct Context<S, P> {
    /// Local clock reading, in milliseconds, stamped into the outgoing message.
    current_time: u64,
    /// Half-open range of peer timestamps (milliseconds) that are accepted.
    ok_timestamps: Range<u64>,
    /// Local long-term signing identity.
    my_identity: S,
    /// Long-term public key the remote peer is expected to sign with.
    peer_identity: P,
}

impl<S, P> Context<S, P> {
    /// Bundles the handshake inputs.
    ///
    /// * `current_time_ms` - local time in milliseconds, sent to the peer.
    /// * `ok_timestamps` - range the peer's timestamp must fall in (end exclusive).
    /// * `my_identity` - local signer.
    /// * `peer_identity` - expected public key of the peer.
    pub fn new(
        current_time_ms: u64,
        ok_timestamps: Range<u64>,
        my_identity: S,
        peer_identity: P,
    ) -> Self {
        let current_time = current_time_ms;
        Self {
            my_identity,
            peer_identity,
            ok_timestamps,
            current_time,
        }
    }
}
/// Begins a handshake as the dialer, producing the first message ([`Syn`]).
///
/// A fresh ephemeral secret key is drawn from `rng`. The following are
/// committed, in this order, to a new [`Transcript`] created with `NAMESPACE`:
/// the dialer's `current_time`, the expected listener identity
/// (`peer_identity`), and the ephemeral public key. The result is then signed
/// with `my_identity`. The dialer's own public key is committed after signing.
///
/// Pass the returned [`DialState`] to [`dial_end`] together with the
/// listener's [`SynAck`].
pub fn dial_start<S: Signer, P: PublicKey>(
    rng: impl CryptoRngCore,
    ctx: Context<S, P>,
) -> (DialState<P>, Syn<<S as Signer>::Signature>) {
    let Context {
        current_time,
        ok_timestamps,
        my_identity,
        peer_identity,
    } = ctx;
    let esk = SecretKey::new(rng);
    let epk = esk.public();
    let mut transcript = Transcript::new(NAMESPACE);
    // `listen_start` re-commits these same three values, in this order, before
    // verifying `sig`; keep the two sequences in lockstep.
    let sig = transcript
        .commit(current_time.encode())
        .commit(peer_identity.encode())
        .commit(epk.encode())
        .sign(&my_identity);
    // Committed after signing; `listen_start` commits the dialer's key (its
    // `peer_identity`) at the matching point, right after verification.
    transcript.commit(my_identity.public_key().encode());
    (
        DialState {
            esk,
            peer_identity,
            transcript,
            // Kept so `dial_end` can validate the listener's timestamp.
            ok_timestamps,
        },
        Syn {
            time_ms: current_time,
            epk,
            sig,
        },
    )
}
/// Completes the dialer's side of the handshake using the listener's [`SynAck`].
///
/// The function runs these steps in order:
/// 1. Checks the listener's timestamp against the range saved in `state`.
/// 2. Commits the listener's timestamp and ephemeral public key to the
///    transcript, then verifies the listener's signature under `peer_identity`.
/// 3. Performs the ephemeral key exchange and commits the shared secret.
/// 4. Derives the session ciphers and both confirmation tags.
///
/// The listener's confirmation tag is checked last, before anything is
/// returned.
///
/// On success, returns three values:
/// * the [`Ack`] to send to the listener;
/// * the dialer's [`SendCipher`] (dialer-to-listener);
/// * the dialer's [`RecvCipher`] (listener-to-dialer).
///
/// # Errors
///
/// * [`Error::InvalidTimestamp`] if `msg.time_ms` is outside the accepted range.
/// * [`Error::HandshakeFailed`] if the signature does not verify, the key
///   exchange yields no secret, or the listener's confirmation does not match.
pub fn dial_end<P: PublicKey>(
    state: DialState<P>,
    msg: SynAck<<P as Verifier>::Signature>,
) -> Result<(Ack, SendCipher, RecvCipher), Error> {
    let DialState {
        esk,
        peer_identity,
        mut transcript,
        ok_timestamps,
    } = state;
    if !ok_timestamps.contains(&msg.time_ms) {
        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
    }
    // Mirrors the commits `listen_start` makes (after the dialer's identity)
    // before it signs.
    if !transcript
        .commit(msg.time_ms.encode())
        .commit(msg.epk.encode())
        .verify(&peer_identity, &msg.sig)
    {
        return Err(Error::HandshakeFailed);
    }
    let Some(secret) = esk.exchange(&msg.epk) else {
        return Err(Error::HandshakeFailed);
    };
    transcript.commit(secret.as_ref());
    // Same derivation order as `listen_start` (L2D, then D2L). Keep them in
    // lockstep in case `noise`/`fork` advance the transcript. The dialer
    // receives on L2D and sends on D2L.
    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_L2D));
    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_D2L));
    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
    // NOTE(review): relies on `Summary`'s `PartialEq`; confirm it is
    // constant-time if timing side channels matter here.
    if msg.confirmation != confirmation_l2d {
        return Err(Error::HandshakeFailed);
    }
    Ok((
        Ack {
            confirmation: confirmation_d2l,
        },
        send,
        recv,
    ))
}
/// Handles the dialer's [`Syn`] and produces the listener's [`SynAck`].
///
/// The function runs these steps in order:
/// 1. Checks `msg.time_ms` against `ok_timestamps`.
/// 2. Rebuilds the transcript the dialer signed: its timestamp, this
///    listener's public key, and its ephemeral key. Verifies the dialer's
///    signature under `peer_identity`.
/// 3. Generates an ephemeral key and commits the dialer's identity, this
///    listener's `current_time`, and the new ephemeral public key.
/// 4. Signs with `my_identity`, performs the key exchange, and commits the
///    shared secret.
/// 5. Derives the ciphers and both confirmation tags.
///
/// The returned [`ListenState`] holds the expected dialer confirmation. Pass
/// it to [`listen_end`] with the dialer's [`Ack`].
///
/// # Errors
///
/// * [`Error::InvalidTimestamp`] if `msg.time_ms` is outside `ok_timestamps`.
/// * [`Error::HandshakeFailed`] if the dialer's signature does not verify or the
///   key exchange yields no secret.
// NOTE(review): takes `rng` as `&mut impl CryptoRngCore`, whereas `dial_start`
// takes `impl CryptoRngCore` by value.
pub fn listen_start<S: Signer, P: PublicKey>(
    rng: &mut impl CryptoRngCore,
    ctx: Context<S, P>,
    msg: Syn<<P as Verifier>::Signature>,
) -> Result<(ListenState, SynAck<<S as Signer>::Signature>), Error> {
    let Context {
        current_time,
        my_identity,
        peer_identity,
        ok_timestamps,
    } = ctx;
    if !ok_timestamps.contains(&msg.time_ms) {
        return Err(Error::InvalidTimestamp(msg.time_ms, ok_timestamps));
    }
    let mut transcript = Transcript::new(NAMESPACE);
    // Same values, same order as the dialer committed in `dial_start` before
    // signing. From the dialer's side, this listener's key was its `peer_identity`.
    if !transcript
        .commit(msg.time_ms.encode())
        .commit(my_identity.public_key().encode())
        .commit(msg.epk.encode())
        .verify(&peer_identity, &msg.sig)
    {
        return Err(Error::HandshakeFailed);
    }
    let esk = SecretKey::new(rng);
    let epk = esk.public();
    // The dialer's key comes first, matching the commit `dial_start` makes after
    // signing. The listener's timestamp and ephemeral key follow; `dial_end`
    // commits both before verifying.
    let sig = transcript
        .commit(peer_identity.encode())
        .commit(current_time.encode())
        .commit(epk.encode())
        .sign(&my_identity);
    let Some(secret) = esk.exchange(&msg.epk) else {
        return Err(Error::HandshakeFailed);
    };
    transcript.commit(secret.as_ref());
    // Same derivation order as `dial_end` (L2D, then D2L). The listener sends
    // on L2D and receives on D2L.
    let send = SendCipher::new(transcript.noise(LABEL_CIPHER_L2D));
    let recv = RecvCipher::new(transcript.noise(LABEL_CIPHER_D2L));
    let confirmation_l2d = transcript.fork(LABEL_CONFIRMATION_L2D).summarize();
    let confirmation_d2l = transcript.fork(LABEL_CONFIRMATION_D2L).summarize();
    Ok((
        ListenState {
            // Expected value of the dialer's `Ack`.
            confirmation: confirmation_d2l,
            send,
            recv,
        },
        SynAck {
            time_ms: current_time,
            epk,
            sig,
            confirmation: confirmation_l2d,
        },
    ))
}
pub fn listen_end(state: ListenState, msg: Ack) -> Result<(SendCipher, RecvCipher), Error> {
if msg.confirmation != state.confirmation {
return Err(Error::HandshakeFailed);
}
Ok((state.send, state.recv))
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{ed25519::PrivateKey, PrivateKeyExt as _, Signer};
    use commonware_codec::{Codec, DecodeExt};
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    /// Asserts that `value` decodes back to an equal value after being encoded.
    fn test_encode_roundtrip<T: Codec<Cfg = ()> + PartialEq>(value: &T) {
        let encoded = value.encode();
        let decoded = <T as DecodeExt<_>>::decode(encoded).unwrap();
        assert!(*value == decoded);
    }

    #[test]
    fn test_can_setup_and_send_messages() -> Result<(), Error> {
        // Every key below is drawn from one seeded RNG, so draw order is fixed:
        // dialer identity, listener identity, dialer ephemeral, listener ephemeral.
        let mut rng = ChaCha8Rng::seed_from_u64(0);
        let dialer_key = PrivateKey::from_rng(&mut rng);
        let listener_key = PrivateKey::from_rng(&mut rng);

        let dialer_ctx = Context::new(0, 0..1, dialer_key.clone(), listener_key.public_key());
        let listener_ctx = Context::new(0, 0..1, listener_key, dialer_key.public_key());

        // Three-message handshake; each wire message must survive a codec round trip.
        let (dial_state, syn) = dial_start(&mut rng, dialer_ctx);
        test_encode_roundtrip(&syn);
        let (listen_state, syn_ack) = listen_start(&mut rng, listener_ctx, syn)?;
        test_encode_roundtrip(&syn_ack);
        let (ack, mut dialer_send, mut dialer_recv) = dial_end(dial_state, syn_ack)?;
        test_encode_roundtrip(&ack);
        let (mut listener_send, mut listener_recv) = listen_end(listen_state, ack)?;

        // Dialer -> listener.
        let outbound: &'static [u8] = b"message 1";
        let received = listener_recv.recv(&dialer_send.send(outbound)?)?;
        assert_eq!(outbound, &received);

        // Listener -> dialer.
        let inbound: &'static [u8] = b"message 2";
        let received = dialer_recv.recv(&listener_send.send(inbound)?)?;
        assert_eq!(inbound, &received);

        Ok(())
    }
}