use super::{
CipherState, EPOCH_ENCRYPTED_SIZE, EPOCH_SIZE, HANDSHAKE_MSG1_SIZE, HANDSHAKE_MSG2_SIZE,
HandshakeProgress, HandshakeRole, NoiseError, NoisePattern, NoiseSession, PROTOCOL_NAME_IK,
PROTOCOL_NAME_XK, PUBKEY_SIZE, XK_HANDSHAKE_MSG1_SIZE, XK_HANDSHAKE_MSG2_SIZE,
XK_HANDSHAKE_MSG3_SIZE,
};
use hkdf::Hkdf;
use rand::Rng;
use secp256k1::{Keypair, PublicKey, Secp256k1, SecretKey, ecdh::shared_secret_point};
use sha2::{Digest, Sha256};
use std::fmt;
/// Noise `SymmetricState`: the chaining key `ck`, the running handshake hash
/// `h`, and the cipher keyed from `ck` whenever DH output is mixed in.
struct SymmetricState {
    /// Chaining key; salt for every HKDF step.
    ck: [u8; 32],
    /// Running SHA-256 hash over the handshake transcript.
    h: [u8; 32],
    /// Cipher keyed by the most recent `mix_key`.
    cipher: CipherState,
}
impl SymmetricState {
    /// Derives the initial `h` (and `ck`) from the protocol name.
    ///
    /// Names of at most 32 bytes are zero-padded to 32 bytes; longer names
    /// are replaced by their SHA-256 digest. The cipher starts without a key.
    fn initialize(protocol_name: &[u8]) -> Self {
        let initial: [u8; 32] = if protocol_name.len() > 32 {
            Sha256::digest(protocol_name).into()
        } else {
            let mut padded = [0u8; 32];
            padded[..protocol_name.len()].copy_from_slice(protocol_name);
            padded
        };
        Self {
            ck: initial,
            h: initial,
            cipher: CipherState::empty(),
        }
    }

    /// HKDF-SHA256 with `chaining_key` as salt and an empty info string.
    /// Returns the 64 output bytes split into two 32-byte halves.
    fn hkdf_pair(chaining_key: &[u8; 32], ikm: &[u8]) -> ([u8; 32], [u8; 32]) {
        let mut okm = [0u8; 64];
        Hkdf::<Sha256>::new(Some(chaining_key.as_slice()), ikm)
            .expand(&[], &mut okm)
            .expect("64 bytes is valid output length");
        let mut first = [0u8; 32];
        let mut second = [0u8; 32];
        first.copy_from_slice(&okm[..32]);
        second.copy_from_slice(&okm[32..]);
        (first, second)
    }

    /// Sets `h = SHA-256(h || data)`.
    fn mix_hash(&mut self, data: &[u8]) {
        let mut digest = Sha256::new();
        digest.update(self.h);
        digest.update(data);
        self.h = digest.finalize().into();
    }

    /// Feeds `input_key_material` through HKDF keyed by `ck`. The first half
    /// becomes the new `ck`; the second half re-keys the cipher.
    fn mix_key(&mut self, input_key_material: &[u8]) {
        let (next_ck, temp_key) = Self::hkdf_pair(&self.ck, input_key_material);
        self.ck = next_ck;
        self.cipher.initialize_key(temp_key);
    }

    /// Encrypts `plaintext` with the current cipher and mixes the ciphertext
    /// into `h`.
    ///
    /// NOTE(review): Noise's EncryptAndHash passes `h` as associated data.
    /// Here only the plaintext reaches `CipherState::encrypt`. Confirm this
    /// deviation is intentional.
    fn encrypt_and_hash(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, NoiseError> {
        let sealed = self.cipher.encrypt(plaintext)?;
        self.mix_hash(&sealed);
        Ok(sealed)
    }

    /// Decrypts `ciphertext` with the current cipher, then mixes the
    /// ciphertext (not the plaintext) into `h`. On failure `h` is untouched.
    fn decrypt_and_hash(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, NoiseError> {
        let opened = self.cipher.decrypt(ciphertext)?;
        self.mix_hash(ciphertext);
        Ok(opened)
    }

    /// Derives the two transport ciphers from `ck` (HKDF with empty input).
    fn split(&self) -> (CipherState, CipherState) {
        let (first_key, second_key) = Self::hkdf_pair(&self.ck, &[]);
        (CipherState::new(first_key), CipherState::new(second_key))
    }

    /// Current value of the handshake hash `h`.
    fn handshake_hash(&self) -> [u8; 32] {
        self.h
    }
}
pub struct HandshakeState {
pattern: NoisePattern,
role: HandshakeRole,
progress: HandshakeProgress,
symmetric: SymmetricState,
static_keypair: Keypair,
ephemeral_keypair: Option<Keypair>,
remote_static: Option<PublicKey>,
remote_ephemeral: Option<PublicKey>,
secp: Secp256k1<secp256k1::All>,
local_epoch: Option<[u8; 8]>,
remote_epoch: Option<[u8; 8]>,
}
impl HandshakeState {
fn normalize_for_premessage(pubkey: &PublicKey) -> [u8; PUBKEY_SIZE] {
let mut bytes = pubkey.serialize();
bytes[0] = 0x02; bytes
}
pub fn new_initiator(static_keypair: Keypair, remote_static: PublicKey) -> Self {
let secp = Secp256k1::new();
let mut state = Self {
pattern: NoisePattern::Ik,
role: HandshakeRole::Initiator,
progress: HandshakeProgress::Initial,
symmetric: SymmetricState::initialize(PROTOCOL_NAME_IK),
static_keypair,
ephemeral_keypair: None,
remote_static: Some(remote_static),
remote_ephemeral: None,
secp,
local_epoch: None,
remote_epoch: None,
};
let normalized = Self::normalize_for_premessage(&remote_static);
state.symmetric.mix_hash(&normalized);
state
}
pub fn new_responder(static_keypair: Keypair) -> Self {
let secp = Secp256k1::new();
let mut state = Self {
pattern: NoisePattern::Ik,
role: HandshakeRole::Responder,
progress: HandshakeProgress::Initial,
symmetric: SymmetricState::initialize(PROTOCOL_NAME_IK),
static_keypair,
ephemeral_keypair: None,
remote_static: None, remote_ephemeral: None,
secp,
local_epoch: None,
remote_epoch: None,
};
let normalized = Self::normalize_for_premessage(&state.static_keypair.public_key());
state.symmetric.mix_hash(&normalized);
state
}
pub fn new_xk_initiator(static_keypair: Keypair, remote_static: PublicKey) -> Self {
let secp = Secp256k1::new();
let mut state = Self {
pattern: NoisePattern::Xk,
role: HandshakeRole::Initiator,
progress: HandshakeProgress::Initial,
symmetric: SymmetricState::initialize(PROTOCOL_NAME_XK),
static_keypair,
ephemeral_keypair: None,
remote_static: Some(remote_static),
remote_ephemeral: None,
secp,
local_epoch: None,
remote_epoch: None,
};
let normalized = Self::normalize_for_premessage(&remote_static);
state.symmetric.mix_hash(&normalized);
state
}
pub fn new_xk_responder(static_keypair: Keypair) -> Self {
let secp = Secp256k1::new();
let mut state = Self {
pattern: NoisePattern::Xk,
role: HandshakeRole::Responder,
progress: HandshakeProgress::Initial,
symmetric: SymmetricState::initialize(PROTOCOL_NAME_XK),
static_keypair,
ephemeral_keypair: None,
remote_static: None, remote_ephemeral: None,
secp,
local_epoch: None,
remote_epoch: None,
};
let normalized = Self::normalize_for_premessage(&state.static_keypair.public_key());
state.symmetric.mix_hash(&normalized);
state
}
pub fn role(&self) -> HandshakeRole {
self.role
}
pub fn progress(&self) -> HandshakeProgress {
self.progress
}
pub fn is_complete(&self) -> bool {
self.progress == HandshakeProgress::Complete
}
pub fn remote_static(&self) -> Option<&PublicKey> {
self.remote_static.as_ref()
}
pub fn set_local_epoch(&mut self, epoch: [u8; 8]) {
self.local_epoch = Some(epoch);
}
pub fn remote_epoch(&self) -> Option<[u8; 8]> {
self.remote_epoch
}
fn generate_ephemeral(&mut self) {
let mut rng = rand::rng();
let mut secret_bytes = [0u8; 32];
rng.fill_bytes(&mut secret_bytes);
let secret_key =
SecretKey::from_slice(&secret_bytes).expect("32 random bytes is valid secret key");
self.ephemeral_keypair = Some(Keypair::from_secret_key(&self.secp, &secret_key));
}
fn ecdh(&self, our_secret: &SecretKey, their_public: &PublicKey) -> [u8; 32] {
let point = shared_secret_point(their_public, our_secret);
let mut hasher = Sha256::new();
hasher.update(&point[..32]);
let hash = hasher.finalize();
let mut result = [0u8; 32];
result.copy_from_slice(&hash);
result
}
pub fn write_message_1(&mut self) -> Result<Vec<u8>, NoiseError> {
if self.role != HandshakeRole::Initiator {
return Err(NoiseError::WrongState {
expected: "initiator".to_string(),
got: "responder".to_string(),
});
}
if self.progress != HandshakeProgress::Initial {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Initial.to_string(),
got: self.progress.to_string(),
});
}
let remote_static = self
.remote_static
.expect("initiator must have remote static");
let epoch = self
.local_epoch
.expect("local epoch must be set before write_message_1");
self.generate_ephemeral();
let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
let e_pub = ephemeral.public_key().serialize();
let mut message = Vec::with_capacity(HANDSHAKE_MSG1_SIZE);
message.extend_from_slice(&e_pub);
self.symmetric.mix_hash(&e_pub);
let es = self.ecdh(&ephemeral.secret_key(), &remote_static);
self.symmetric.mix_key(&es);
let our_static = self.static_keypair.public_key().serialize();
let encrypted_static = self.symmetric.encrypt_and_hash(&our_static)?;
message.extend_from_slice(&encrypted_static);
let ss = self.ecdh(&self.static_keypair.secret_key(), &remote_static);
self.symmetric.mix_key(&ss);
let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
message.extend_from_slice(&encrypted_epoch);
self.progress = HandshakeProgress::Message1Done;
Ok(message)
}
pub fn read_message_1(&mut self, message: &[u8]) -> Result<(), NoiseError> {
if self.role != HandshakeRole::Responder {
return Err(NoiseError::WrongState {
expected: "responder".to_string(),
got: "initiator".to_string(),
});
}
if self.progress != HandshakeProgress::Initial {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Initial.to_string(),
got: self.progress.to_string(),
});
}
if message.len() != HANDSHAKE_MSG1_SIZE {
return Err(NoiseError::MessageTooShort {
expected: HANDSHAKE_MSG1_SIZE,
got: message.len(),
});
}
let re = PublicKey::from_slice(&message[..PUBKEY_SIZE])
.map_err(|_| NoiseError::InvalidPublicKey)?;
self.remote_ephemeral = Some(re);
self.symmetric.mix_hash(&message[..PUBKEY_SIZE]);
let es = self.ecdh(&self.static_keypair.secret_key(), &re);
self.symmetric.mix_key(&es);
let encrypted_static_end = PUBKEY_SIZE + PUBKEY_SIZE + super::TAG_SIZE;
let encrypted_static = &message[PUBKEY_SIZE..encrypted_static_end];
let decrypted_static = self.symmetric.decrypt_and_hash(encrypted_static)?;
let rs =
PublicKey::from_slice(&decrypted_static).map_err(|_| NoiseError::InvalidPublicKey)?;
self.remote_static = Some(rs);
let ss = self.ecdh(&self.static_keypair.secret_key(), &rs);
self.symmetric.mix_key(&ss);
let encrypted_epoch = &message[encrypted_static_end..];
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
let mut epoch = [0u8; EPOCH_SIZE];
epoch.copy_from_slice(&decrypted_epoch);
self.remote_epoch = Some(epoch);
self.progress = HandshakeProgress::Message1Done;
Ok(())
}
pub fn write_message_2(&mut self) -> Result<Vec<u8>, NoiseError> {
if self.role != HandshakeRole::Responder {
return Err(NoiseError::WrongState {
expected: "responder".to_string(),
got: "initiator".to_string(),
});
}
if self.progress != HandshakeProgress::Message1Done {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Message1Done.to_string(),
got: self.progress.to_string(),
});
}
let re = self.remote_ephemeral.expect("should have remote ephemeral");
let epoch = self
.local_epoch
.expect("local epoch must be set before write_message_2");
self.generate_ephemeral();
let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
let e_pub = ephemeral.public_key().serialize();
let mut message = Vec::with_capacity(HANDSHAKE_MSG2_SIZE);
message.extend_from_slice(&e_pub);
self.symmetric.mix_hash(&e_pub);
let ee = self.ecdh(&ephemeral.secret_key(), &re);
self.symmetric.mix_key(&ee);
let se = self.ecdh(&self.static_keypair.secret_key(), &re);
self.symmetric.mix_key(&se);
let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
message.extend_from_slice(&encrypted_epoch);
self.progress = HandshakeProgress::Complete;
Ok(message)
}
pub fn read_message_2(&mut self, message: &[u8]) -> Result<(), NoiseError> {
if self.role != HandshakeRole::Initiator {
return Err(NoiseError::WrongState {
expected: "initiator".to_string(),
got: "responder".to_string(),
});
}
if self.progress != HandshakeProgress::Message1Done {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Message1Done.to_string(),
got: self.progress.to_string(),
});
}
if message.len() != HANDSHAKE_MSG2_SIZE {
return Err(NoiseError::MessageTooShort {
expected: HANDSHAKE_MSG2_SIZE,
got: message.len(),
});
}
let e_pub = &message[..PUBKEY_SIZE];
let re = PublicKey::from_slice(e_pub).map_err(|_| NoiseError::InvalidPublicKey)?;
self.remote_ephemeral = Some(re);
self.symmetric.mix_hash(e_pub);
let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
let ee = self.ecdh(&ephemeral.secret_key(), &re);
self.symmetric.mix_key(&ee);
let rs = self.remote_static.expect("initiator has remote static");
let se = self.ecdh(&ephemeral.secret_key(), &rs);
self.symmetric.mix_key(&se);
let encrypted_epoch = &message[PUBKEY_SIZE..];
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
let mut epoch = [0u8; EPOCH_SIZE];
epoch.copy_from_slice(&decrypted_epoch);
self.remote_epoch = Some(epoch);
self.progress = HandshakeProgress::Complete;
Ok(())
}
pub fn write_xk_message_1(&mut self) -> Result<Vec<u8>, NoiseError> {
if self.role != HandshakeRole::Initiator {
return Err(NoiseError::WrongState {
expected: "initiator".to_string(),
got: "responder".to_string(),
});
}
if self.progress != HandshakeProgress::Initial {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Initial.to_string(),
got: self.progress.to_string(),
});
}
let remote_static = self
.remote_static
.expect("initiator must have remote static");
self.generate_ephemeral();
let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
let e_pub = ephemeral.public_key().serialize();
let mut message = Vec::with_capacity(XK_HANDSHAKE_MSG1_SIZE);
message.extend_from_slice(&e_pub);
self.symmetric.mix_hash(&e_pub);
let es = self.ecdh(&ephemeral.secret_key(), &remote_static);
self.symmetric.mix_key(&es);
self.progress = HandshakeProgress::Message1Done;
Ok(message)
}
pub fn read_xk_message_1(&mut self, message: &[u8]) -> Result<(), NoiseError> {
if self.role != HandshakeRole::Responder {
return Err(NoiseError::WrongState {
expected: "responder".to_string(),
got: "initiator".to_string(),
});
}
if self.progress != HandshakeProgress::Initial {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Initial.to_string(),
got: self.progress.to_string(),
});
}
if message.len() != XK_HANDSHAKE_MSG1_SIZE {
return Err(NoiseError::MessageTooShort {
expected: XK_HANDSHAKE_MSG1_SIZE,
got: message.len(),
});
}
let re = PublicKey::from_slice(&message[..PUBKEY_SIZE])
.map_err(|_| NoiseError::InvalidPublicKey)?;
self.remote_ephemeral = Some(re);
self.symmetric.mix_hash(&message[..PUBKEY_SIZE]);
let es = self.ecdh(&self.static_keypair.secret_key(), &re);
self.symmetric.mix_key(&es);
self.progress = HandshakeProgress::Message1Done;
Ok(())
}
pub fn write_xk_message_2(&mut self) -> Result<Vec<u8>, NoiseError> {
if self.role != HandshakeRole::Responder {
return Err(NoiseError::WrongState {
expected: "responder".to_string(),
got: "initiator".to_string(),
});
}
if self.progress != HandshakeProgress::Message1Done {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Message1Done.to_string(),
got: self.progress.to_string(),
});
}
let re = self.remote_ephemeral.expect("should have remote ephemeral");
let epoch = self
.local_epoch
.expect("local epoch must be set before write_xk_message_2");
self.generate_ephemeral();
let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
let e_pub = ephemeral.public_key().serialize();
let mut message = Vec::with_capacity(XK_HANDSHAKE_MSG2_SIZE);
message.extend_from_slice(&e_pub);
self.symmetric.mix_hash(&e_pub);
let ee = self.ecdh(&ephemeral.secret_key(), &re);
self.symmetric.mix_key(&ee);
let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
message.extend_from_slice(&encrypted_epoch);
self.progress = HandshakeProgress::Message2Done;
Ok(message)
}
pub fn read_xk_message_2(&mut self, message: &[u8]) -> Result<(), NoiseError> {
if self.role != HandshakeRole::Initiator {
return Err(NoiseError::WrongState {
expected: "initiator".to_string(),
got: "responder".to_string(),
});
}
if self.progress != HandshakeProgress::Message1Done {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Message1Done.to_string(),
got: self.progress.to_string(),
});
}
if message.len() != XK_HANDSHAKE_MSG2_SIZE {
return Err(NoiseError::MessageTooShort {
expected: XK_HANDSHAKE_MSG2_SIZE,
got: message.len(),
});
}
let e_pub = &message[..PUBKEY_SIZE];
let re = PublicKey::from_slice(e_pub).map_err(|_| NoiseError::InvalidPublicKey)?;
self.remote_ephemeral = Some(re);
self.symmetric.mix_hash(e_pub);
let ephemeral = self.ephemeral_keypair.as_ref().unwrap();
let ee = self.ecdh(&ephemeral.secret_key(), &re);
self.symmetric.mix_key(&ee);
let encrypted_epoch = &message[PUBKEY_SIZE..];
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
let mut epoch = [0u8; EPOCH_SIZE];
epoch.copy_from_slice(&decrypted_epoch);
self.remote_epoch = Some(epoch);
self.progress = HandshakeProgress::Message2Done;
Ok(())
}
pub fn write_xk_message_3(&mut self) -> Result<Vec<u8>, NoiseError> {
if self.role != HandshakeRole::Initiator {
return Err(NoiseError::WrongState {
expected: "initiator".to_string(),
got: "responder".to_string(),
});
}
if self.progress != HandshakeProgress::Message2Done {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Message2Done.to_string(),
got: self.progress.to_string(),
});
}
let re = self
.remote_ephemeral
.expect("should have remote ephemeral after msg2");
let epoch = self
.local_epoch
.expect("local epoch must be set before write_xk_message_3");
let mut message = Vec::with_capacity(XK_HANDSHAKE_MSG3_SIZE);
let our_static = self.static_keypair.public_key().serialize();
let encrypted_static = self.symmetric.encrypt_and_hash(&our_static)?;
message.extend_from_slice(&encrypted_static);
let se = self.ecdh(&self.static_keypair.secret_key(), &re);
self.symmetric.mix_key(&se);
let encrypted_epoch = self.symmetric.encrypt_and_hash(&epoch)?;
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
message.extend_from_slice(&encrypted_epoch);
self.progress = HandshakeProgress::Complete;
Ok(message)
}
pub fn read_xk_message_3(&mut self, message: &[u8]) -> Result<(), NoiseError> {
if self.role != HandshakeRole::Responder {
return Err(NoiseError::WrongState {
expected: "responder".to_string(),
got: "initiator".to_string(),
});
}
if self.progress != HandshakeProgress::Message2Done {
return Err(NoiseError::WrongState {
expected: HandshakeProgress::Message2Done.to_string(),
got: self.progress.to_string(),
});
}
if message.len() != XK_HANDSHAKE_MSG3_SIZE {
return Err(NoiseError::MessageTooShort {
expected: XK_HANDSHAKE_MSG3_SIZE,
got: message.len(),
});
}
let encrypted_static_end = PUBKEY_SIZE + super::TAG_SIZE;
let encrypted_static = &message[..encrypted_static_end];
let decrypted_static = self.symmetric.decrypt_and_hash(encrypted_static)?;
let rs =
PublicKey::from_slice(&decrypted_static).map_err(|_| NoiseError::InvalidPublicKey)?;
self.remote_static = Some(rs);
let ephemeral = self
.ephemeral_keypair
.as_ref()
.expect("should have ephemeral after msg2");
let se = self.ecdh(&ephemeral.secret_key(), &rs);
self.symmetric.mix_key(&se);
let encrypted_epoch = &message[encrypted_static_end..];
debug_assert_eq!(encrypted_epoch.len(), EPOCH_ENCRYPTED_SIZE);
let decrypted_epoch = self.symmetric.decrypt_and_hash(encrypted_epoch)?;
debug_assert_eq!(decrypted_epoch.len(), EPOCH_SIZE);
let mut epoch = [0u8; EPOCH_SIZE];
epoch.copy_from_slice(&decrypted_epoch);
self.remote_epoch = Some(epoch);
self.progress = HandshakeProgress::Complete;
Ok(())
}
pub fn into_session(self) -> Result<NoiseSession, NoiseError> {
if !self.is_complete() {
return Err(NoiseError::HandshakeNotComplete);
}
let (c1, c2) = self.symmetric.split();
let handshake_hash = self.symmetric.handshake_hash();
let remote_static = self
.remote_static
.expect("remote static must be known after handshake");
let (send_cipher, recv_cipher) = match self.role {
HandshakeRole::Initiator => (c1, c2),
HandshakeRole::Responder => (c2, c1),
};
Ok(NoiseSession::from_handshake(
self.role,
send_cipher,
recv_cipher,
handshake_hash,
remote_static,
))
}
pub fn handshake_hash(&self) -> [u8; 32] {
self.symmetric.handshake_hash()
}
}
impl fmt::Debug for HandshakeState {
    /// Shows the pattern, role and progress, plus presence flags for the
    /// optional keys and epochs. The key and epoch values themselves are not
    /// printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            pattern,
            role,
            progress,
            ephemeral_keypair,
            remote_static,
            remote_ephemeral,
            local_epoch,
            remote_epoch,
            ..
        } = self;
        let mut out = f.debug_struct("HandshakeState");
        out.field("pattern", pattern);
        out.field("role", role);
        out.field("progress", progress);
        out.field("has_ephemeral", &ephemeral_keypair.is_some());
        out.field("has_remote_static", &remote_static.is_some());
        out.field("has_remote_ephemeral", &remote_ephemeral.is_some());
        out.field("has_local_epoch", &local_epoch.is_some());
        out.field("has_remote_epoch", &remote_epoch.is_some());
        out.finish()
    }
}