use crate::{Error, Result};
use sha2::{Sha256, Digest};
use x25519_dalek::{EphemeralSecret, PublicKey, SharedSecret};
use rand::rngs::OsRng;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Protocol name string; `NoiseSession` passes it to `SymmetricState::new`
/// to seed the handshake hash and chaining key.
pub const NOISE_PROTOCOL: &str = "Noise_XX_25519_AESGCM_SHA256";
/// Progress of a handshake as tracked by `NoiseSession`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandshakeState {
/// Starting state of every session; nothing has been written or read yet.
Initial,
/// Initiator: first message written, reply not yet read.
/// Responder: first message read, reply not yet written.
WaitingForResponse,
/// Both handshake messages have been processed.
Complete,
/// NOTE(review): never assigned by any code visible in this file —
/// presumably reserved for error reporting; confirm or remove.
Failed,
}
/// An X25519 key pair: an `EphemeralSecret` together with the `PublicKey`
/// derived from it.
pub struct NoiseKeyPair {
/// Secret half; moved out and consumed by `dh`.
secret: EphemeralSecret,
/// Public half, computed from `secret` in `generate`.
public: PublicKey,
}
impl NoiseKeyPair {
    /// Creates a fresh key pair, drawing the secret from `OsRng`.
    pub fn generate() -> Self {
        let secret = EphemeralSecret::random_from_rng(OsRng);
        Self {
            // The public key is derived from a borrow before `secret` is moved in.
            public: PublicKey::from(&secret),
            secret,
        }
    }

    /// Borrows the public half of the pair.
    pub fn public_key(&self) -> &PublicKey {
        &self.public
    }

    /// Returns a copy of the public key as 32 raw bytes.
    pub fn public_key_bytes(&self) -> [u8; 32] {
        self.public.to_bytes()
    }

    /// Runs Diffie-Hellman against `their_public`.
    ///
    /// Takes `self` by value: the secret is consumed by the exchange, so a
    /// key pair can take part in at most one `dh` call.
    pub fn dh(self, their_public: &PublicKey) -> SharedSecret {
        let Self { secret, .. } = self;
        secret.diffie_hellman(their_public)
    }
}
/// A symmetric key paired with a message counter.
///
/// The `Zeroize`/`ZeroizeOnDrop` derives wipe both fields when the value is
/// dropped.
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct CipherState {
/// 32-byte key handed to the `crate::crypto::aead` functions.
key: [u8; 32],
/// Counter that starts at 0 and is advanced by `encrypt`.
nonce: u64,
}
impl CipherState {
    /// Creates a cipher state for `key` with the message counter at zero.
    pub fn new(key: [u8; 32]) -> Self {
        Self { key, nonce: 0 }
    }

    /// Encrypts `plaintext` under the stored key with associated data `ad`.
    ///
    /// The message counter is advanced before the AEAD call, so it also
    /// moves forward when encryption fails.
    ///
    /// # Errors
    /// Returns whatever error `crate::crypto::aead::encrypt` reports.
    pub fn encrypt(&mut self, plaintext: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
        use crate::crypto::aead;

        // NOTE(review): the counter is NOT passed to `aead::encrypt`; the
        // helper's signature only takes (plaintext, key, ad). A 12-byte nonce
        // buffer used to be built here and then discarded, so it has been
        // removed as dead code. Presumably `aead::encrypt` picks its own
        // nonce — confirm. If the counter is meant to be the AEAD nonce, the
        // Noise specification's AESGCM encoding is 4 zero bytes followed by
        // the big-endian counter, and the helper needs a nonce parameter.
        self.nonce += 1;
        aead::encrypt(plaintext, &self.key, ad)
    }

    /// Decrypts `ciphertext` under the stored key with associated data `ad`.
    ///
    /// # Errors
    /// Returns whatever error `crate::crypto::aead::decrypt` reports.
    pub fn decrypt(&mut self, ciphertext: &[u8], ad: &[u8]) -> Result<Vec<u8>> {
        use crate::crypto::aead;

        // NOTE(review): unlike `encrypt`, this path leaves the counter
        // untouched, so the counter only tracks outgoing messages. Verify
        // that replay protection is handled elsewhere.
        aead::decrypt(ciphertext, &self.key, ad)
    }
}
/// Hashing and key-derivation state carried through a handshake.
pub struct SymmetricState {
/// Chaining key: HKDF salt in `mix_key` and `split`, replaced by `mix_key`.
ck: [u8; 32],
/// Running SHA-256 hash, updated by `mix_hash`.
h: [u8; 32],
/// Cipher installed by the most recent `mix_key`; `None` until then.
cipher: Option<CipherState>,
}
impl SymmetricState {
pub fn new(protocol_name: &str) -> Self {
let mut h = [0u8; 32];
if protocol_name.len() <= 32 {
h[..protocol_name.len()].copy_from_slice(protocol_name.as_bytes());
} else {
h = Sha256::digest(protocol_name.as_bytes()).into();
}
Self {
ck: h,
h,
cipher: None,
}
}
pub fn mix_hash(&mut self, data: &[u8]) {
let mut hasher = Sha256::new();
hasher.update(&self.h);
hasher.update(data);
self.h = hasher.finalize().into();
}
pub fn mix_key(&mut self, input_key_material: &[u8]) {
use hkdf::Hkdf;
let hk = Hkdf::<Sha256>::new(Some(&self.ck), input_key_material);
let mut output = [0u8; 64];
hk.expand(b"", &mut output).unwrap();
self.ck.copy_from_slice(&output[..32]);
let mut temp_k = [0u8; 32];
temp_k.copy_from_slice(&output[32..64]);
self.cipher = Some(CipherState::new(temp_k));
}
pub fn handshake_hash(&self) -> &[u8; 32] {
&self.h
}
pub fn split(self) -> Result<(CipherState, CipherState)> {
use hkdf::Hkdf;
let hk = Hkdf::<Sha256>::new(Some(&self.ck), &[]);
let mut output = [0u8; 64];
hk.expand(b"", &mut output).unwrap();
let mut send_key = [0u8; 32];
let mut recv_key = [0u8; 32];
send_key.copy_from_slice(&output[..32]);
recv_key.copy_from_slice(&output[32..64]);
Ok((CipherState::new(send_key), CipherState::new(recv_key)))
}
}
/// State of one handshake, for either the initiator or the responder role.
pub struct NoiseSession {
/// Local static key pair supplied to the constructor.
static_keypair: Option<NoiseKeyPair>,
/// Local ephemeral key pair; `None` until a handshake message is written.
ephemeral_keypair: Option<NoiseKeyPair>,
/// NOTE(review): never assigned after construction in the code visible
/// here — presumably the peer's static key; confirm.
their_static: Option<PublicKey>,
/// Peer's ephemeral public key, set by `read_message`.
their_ephemeral: Option<PublicKey>,
/// Hash / chaining-key state updated as messages are processed.
symmetric: SymmetricState,
/// Current position in the handshake.
state: HandshakeState,
/// `true` for sessions built with `initiator`, `false` for `responder`.
initiator: bool,
}
impl NoiseSession {
pub fn initiator(static_keypair: NoiseKeyPair) -> Self {
let mut symmetric = SymmetricState::new(NOISE_PROTOCOL);
symmetric.mix_hash(static_keypair.public_key().as_bytes());
Self {
static_keypair: Some(static_keypair),
ephemeral_keypair: None,
their_static: None,
their_ephemeral: None,
symmetric,
state: HandshakeState::Initial,
initiator: true,
}
}
pub fn responder(static_keypair: NoiseKeyPair) -> Self {
let symmetric = SymmetricState::new(NOISE_PROTOCOL);
Self {
static_keypair: Some(static_keypair),
ephemeral_keypair: None,
their_static: None,
their_ephemeral: None,
symmetric,
state: HandshakeState::Initial,
initiator: false,
}
}
pub fn state(&self) -> HandshakeState {
self.state
}
pub fn is_complete(&self) -> bool {
self.state == HandshakeState::Complete
}
pub fn write_message(&mut self, payload: &[u8]) -> Result<Vec<u8>> {
match self.state {
HandshakeState::Initial if self.initiator => {
let ephemeral = NoiseKeyPair::generate();
let e_pub = ephemeral.public_key_bytes();
self.symmetric.mix_hash(&e_pub);
let mut message = e_pub.to_vec();
message.extend_from_slice(payload);
self.ephemeral_keypair = Some(ephemeral);
self.state = HandshakeState::WaitingForResponse;
Ok(message)
}
HandshakeState::WaitingForResponse if !self.initiator => {
let ephemeral = NoiseKeyPair::generate();
let e_pub = ephemeral.public_key_bytes();
self.symmetric.mix_hash(&e_pub);
self.state = HandshakeState::Complete;
Ok(e_pub.to_vec())
}
HandshakeState::Complete => {
Err(Error::KeyExchange("Handshake already complete".into()))
}
_ => Err(Error::KeyExchange("Invalid handshake state".into())),
}
}
pub fn read_message(&mut self, message: &[u8]) -> Result<Vec<u8>> {
if message.len() < 32 {
return Err(Error::InvalidMessage("Message too short".into()));
}
match self.state {
HandshakeState::Initial if !self.initiator => {
let mut e_pub_bytes = [0u8; 32];
e_pub_bytes.copy_from_slice(&message[..32]);
let their_ephemeral = PublicKey::from(e_pub_bytes);
self.symmetric.mix_hash(&e_pub_bytes);
self.their_ephemeral = Some(their_ephemeral);
self.state = HandshakeState::WaitingForResponse;
let payload = message[32..].to_vec();
Ok(payload)
}
HandshakeState::WaitingForResponse if self.initiator => {
let mut e_pub_bytes = [0u8; 32];
e_pub_bytes.copy_from_slice(&message[..32]);
let their_ephemeral = PublicKey::from(e_pub_bytes);
self.symmetric.mix_hash(&e_pub_bytes);
self.their_ephemeral = Some(their_ephemeral);
self.state = HandshakeState::Complete;
Ok(Vec::new())
}
HandshakeState::Complete => {
Err(Error::KeyExchange("Handshake already complete".into()))
}
_ => Err(Error::InvalidMessage("Unexpected message".into())),
}
}
pub fn into_transport(self) -> Result<(CipherState, CipherState)> {
if self.state != HandshakeState::Complete {
return Err(Error::KeyExchange("Handshake not complete".into()));
}
self.symmetric.split()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A generated key pair exposes a 32-byte public key.
    #[test]
    fn test_noise_keypair() {
        let public = NoiseKeyPair::generate().public_key_bytes();
        assert_eq!(public.len(), 32);
    }

    /// A fresh symmetric state carries a 32-byte handshake hash.
    #[test]
    fn test_symmetric_state() {
        let symmetric = SymmetricState::new(NOISE_PROTOCOL);
        let hash = symmetric.handshake_hash();
        assert_eq!(hash.len(), 32);
    }

    /// Writing the first message yields at least the 32-byte key and moves
    /// the initiator into `WaitingForResponse`.
    #[test]
    fn test_handshake_initiator_first_message() {
        let mut session = NoiseSession::initiator(NoiseKeyPair::generate());
        let first = session.write_message(b"hello").unwrap();
        assert!(first.len() >= 32);
        assert_eq!(session.state(), HandshakeState::WaitingForResponse);
    }
}