#![allow(clippy::integer_arithmetic)]
use crate::{hash::HashValue, hkdf::Hkdf, traits::Uniform as _, x25519};
use aes_gcm::{
aead::{generic_array::GenericArray, Aead, AeadInPlace, NewAead, Payload},
Aes256Gcm,
};
use sha2::Digest;
use std::{
convert::TryFrom as _,
io::{Cursor, Read as _, Write as _},
};
use thiserror::Error;
/// Largest Noise message, in bytes, this module will produce or accept (equal to `u16::MAX`).
pub const MAX_SIZE_NOISE_MSG: usize = u16::MAX as usize;
/// Size, in bytes, of the AES-GCM authentication tag appended to every ciphertext.
pub const AES_GCM_TAGLEN: usize = 16;
/// Noise protocol name, NUL-padded to 32 bytes so it can seed both `h` and `ck` directly.
const PROTOCOL_NAME: &[u8] = b"Noise_IK_25519_AESGCM_SHA256\0\0\0\0";
/// Size, in bytes, of an AES-GCM nonce.
const AES_NONCE_SIZE: usize = 12;
/// Returns the AES-GCM ciphertext length for `plaintext_len` bytes of plaintext:
/// the plaintext itself plus the `AES_GCM_TAGLEN`-byte authentication tag.
pub const fn encrypted_len(plaintext_len: usize) -> usize {
    AES_GCM_TAGLEN + plaintext_len
}
/// Returns the plaintext length recovered from an AES-GCM ciphertext of
/// `ciphertext_len` bytes (the ciphertext minus its `AES_GCM_TAGLEN`-byte tag).
///
/// `ciphertext_len` must be at least `AES_GCM_TAGLEN`: the subtraction is unchecked
/// `usize` arithmetic, so a shorter length panics when overflow checks are enabled
/// and wraps around otherwise.
pub const fn decrypted_len(ciphertext_len: usize) -> usize {
    ciphertext_len - AES_GCM_TAGLEN
}
/// Size of the initiator's first handshake message for a payload of `payload_len` bytes.
///
/// Layout: cleartext ephemeral key `e` || encrypted static key `s` || encrypted payload.
pub const fn handshake_init_msg_len(payload_len: usize) -> usize {
    x25519::PUBLIC_KEY_SIZE
        + encrypted_len(x25519::PUBLIC_KEY_SIZE)
        + encrypted_len(payload_len)
}
/// Size of the responder's handshake message for a payload of `payload_len` bytes.
///
/// Layout: cleartext ephemeral key `e` || encrypted payload.
pub const fn handshake_resp_msg_len(payload_len: usize) -> usize {
    x25519::PUBLIC_KEY_SIZE + encrypted_len(payload_len)
}
// Compile-time assertion that `HashValue::LENGTH` is 32: if it ever differs, the two
// array types below disagree and the crate no longer builds.
#[rustfmt::skip]
const _: [(); HashValue::LENGTH] = [(); 32];
/// Errors produced by the Noise handshake (`NoiseConfig`) and transport (`NoiseSession`).
#[derive(Debug, Error)]
pub enum NoiseError {
    /// A handshake message ended before a required fixed-size field could be read.
    #[error("noise: the received message is too short to contain the expected data")]
    MsgTooShort,
    /// HKDF extract-then-expand returned an error.
    #[error("noise: HKDF has failed")]
    Hkdf,
    /// AES-256-GCM encryption returned an error.
    #[error("noise: encryption has failed")]
    Encrypt,
    /// AES-256-GCM decryption or tag verification failed.
    #[error("noise: could not decrypt the received data")]
    Decrypt,
    /// The decrypted static key could not be converted into an x25519 public key.
    #[error("noise: the public key received is of the wrong format")]
    WrongPublicKeyReceived,
    /// The session was invalidated by an earlier failed read.
    #[error("noise: session was closed due to decrypt error")]
    SessionClosed,
    /// The outgoing message would exceed `MAX_SIZE_NOISE_MSG`.
    #[error("noise: the payload that we are trying to send is too large")]
    PayloadTooLarge,
    /// A received message exceeds `MAX_SIZE_NOISE_MSG`.
    #[error("noise: the message we received is too large")]
    ReceivedMsgTooLarge,
    /// A caller-supplied output buffer is too small. Also returned by
    /// `NoiseSession::read_message_in_place` for inputs shorter than the tag.
    #[error("noise: the response buffer passed as argument is too small")]
    ResponseBufferTooSmall,
    /// A session nonce counter cannot be incremented any further.
    #[error("noise: the nonce exceeds the maximum u64 value")]
    NonceOverflow,
}
/// Returns the SHA-256 digest of `data` as an owned byte vector.
fn hash(data: &[u8]) -> Vec<u8> {
    let digest = sha2::Sha256::digest(data);
    digest.as_slice().to_vec()
}
/// Noise's two-output HKDF-SHA256 keyed by the chaining key `ck`.
///
/// When `dh_output` is `Some` non-empty slice it is used as input key material.
/// When it is `None` or empty, the no-IKM variant is used instead (this is the form
/// `Split` uses). Exactly 64 bytes are requested and returned as two 32-byte halves.
fn hkdf(ck: &[u8], dh_output: Option<&[u8]>) -> Result<(Vec<u8>, Vec<u8>), NoiseError> {
    let okm = match dh_output {
        Some(ikm) if !ikm.is_empty() => {
            Hkdf::<sha2::Sha256>::extract_then_expand(Some(ck), ikm, None, 64)
        }
        _ => Hkdf::<sha2::Sha256>::extract_then_expand_no_ikm(Some(ck), None, 64),
    }
    .map_err(|_| NoiseError::Hkdf)?;
    let (first, second) = okm.split_at(32);
    Ok((first.to_vec(), second.to_vec()))
}
/// Noise `MixHash`: replaces `h` with `SHA-256(h || data)`.
fn mix_hash(h: &mut Vec<u8>, data: &[u8]) {
    let mut transcript = std::mem::take(h);
    transcript.extend_from_slice(data);
    *h = hash(&transcript);
}
/// Noise `MixKey`: derives `(ck', k)` from `ck` and `dh_output` via HKDF, stores `ck'`
/// back into `ck`, and returns the temporary encryption key `k`.
fn mix_key(ck: &mut Vec<u8>, dh_output: &[u8]) -> Result<Vec<u8>, NoiseError> {
    hkdf(ck, Some(dh_output)).map(|(next_chaining_key, temp_key)| {
        *ck = next_chaining_key;
        temp_key
    })
}
/// Local static configuration for Noise IK handshakes: the x25519 static key pair.
#[derive(Debug)]
pub struct NoiseConfig {
    /// Local static private key (`s` in the Noise specification).
    private_key: x25519::PrivateKey,
    /// Public key derived from `private_key` when the config is created.
    public_key: x25519::PublicKey,
}
/// State the initiator carries from `NoiseConfig::initiate_connection` to
/// `NoiseConfig::finalize_connection`.
#[cfg_attr(test, derive(Clone))]
pub struct InitiatorHandshakeState {
    /// Responder's static public key (`rs`).
    rs: x25519::PublicKey,
    /// Initiator's freshly generated ephemeral private key (`e`).
    e: x25519::PrivateKey,
    /// Chaining key (`ck`).
    ck: Vec<u8>,
    /// Running handshake hash (`h`).
    h: Vec<u8>,
}
/// State the responder carries from `NoiseConfig::parse_client_init_message` to
/// `NoiseConfig::respond_to_client`.
#[cfg_attr(test, derive(Clone))]
pub struct ResponderHandshakeState {
    /// Initiator's ephemeral public key (`re`), read in cleartext from the first message.
    re: x25519::PublicKey,
    /// Initiator's static public key (`rs`), decrypted from the first message.
    rs: x25519::PublicKey,
    /// Chaining key (`ck`).
    ck: Vec<u8>,
    /// Running handshake hash (`h`).
    h: Vec<u8>,
}
impl NoiseConfig {
    /// Creates a configuration from the local static private key. The matching public
    /// key is derived once and cached.
    pub fn new(private_key: x25519::PrivateKey) -> Self {
        let public_key = private_key.public_key();
        Self {
            private_key,
            public_key,
        }
    }

    /// Returns the local static public key.
    pub fn public_key(&self) -> x25519::PublicKey {
        self.public_key
    }

    /// AES-256-GCM-encrypts `plaintext` under key `k`, authenticating `aad` (the current
    /// handshake hash), and returns ciphertext || tag.
    ///
    /// Every key produced by `mix_key` in this impl feeds exactly one AEAD operation,
    /// so the all-zero nonce is never reused under the same key.
    fn encrypt_with_ad(k: &[u8], aad: &[u8], plaintext: &[u8]) -> Result<Vec<u8>, NoiseError> {
        let aead = Aes256Gcm::new(GenericArray::from_slice(k));
        let nonce = GenericArray::from_slice(&[0u8; AES_NONCE_SIZE]);
        aead.encrypt(nonce, Payload { msg: plaintext, aad })
            .map_err(|_| NoiseError::Encrypt)
    }

    /// Inverse of `encrypt_with_ad`: verifies and decrypts `ciphertext` (tag included)
    /// under key `k` and associated data `aad`, using the all-zero nonce.
    fn decrypt_with_ad(k: &[u8], aad: &[u8], ciphertext: &[u8]) -> Result<Vec<u8>, NoiseError> {
        let aead = Aes256Gcm::new(GenericArray::from_slice(k));
        let nonce = GenericArray::from_slice(&[0u8; AES_NONCE_SIZE]);
        aead.decrypt(nonce, Payload { msg: ciphertext, aad })
            .map_err(|_| NoiseError::Decrypt)
    }

    /// Starts an IK handshake as the initiator towards `remote_public`.
    ///
    /// Writes the first handshake message (`e, es, s, ss` followed by the encrypted
    /// `payload`) into the first `handshake_init_msg_len(payload_len)` bytes of
    /// `response_buffer`. Returns the state needed by `finalize_connection`.
    ///
    /// # Errors
    /// - `PayloadTooLarge` if the message would exceed `MAX_SIZE_NOISE_MSG`.
    /// - `ResponseBufferTooSmall` if `response_buffer` cannot hold the whole message.
    /// - `Hkdf` / `Encrypt` if key derivation or encryption fails.
    pub fn initiate_connection(
        &self,
        rng: &mut (impl rand::RngCore + rand::CryptoRng),
        prologue: &[u8],
        remote_public: x25519::PublicKey,
        payload: Option<&[u8]>,
        response_buffer: &mut [u8],
    ) -> Result<InitiatorHandshakeState, NoiseError> {
        // checks
        let payload_len = payload.map(<[u8]>::len).unwrap_or(0);
        let buffer_size_required = handshake_init_msg_len(payload_len);
        if buffer_size_required > MAX_SIZE_NOISE_MSG {
            return Err(NoiseError::PayloadTooLarge);
        }
        if response_buffer.len() < buffer_size_required {
            return Err(NoiseError::ResponseBufferTooSmall);
        }

        // initialize: h and ck both start as the padded protocol name
        let mut h = PROTOCOL_NAME.to_vec();
        let mut ck = PROTOCOL_NAME.to_vec();
        let rs = remote_public; // naming follows the Noise specification
        mix_hash(&mut h, prologue);
        mix_hash(&mut h, rs.as_slice());

        // -> e
        let e = x25519::PrivateKey::generate(rng);
        let e_pub = e.public_key();
        mix_hash(&mut h, e_pub.as_slice());
        let mut response_buffer = Cursor::new(response_buffer);
        // `write_all` (not `write`) so a short write is reported instead of silently
        // truncating the handshake message.
        response_buffer
            .write_all(e_pub.as_slice())
            .map_err(|_| NoiseError::ResponseBufferTooSmall)?;

        // -> es
        let dh_output = e.diffie_hellman(&rs);
        let k = mix_key(&mut ck, &dh_output)?;

        // -> s
        let encrypted_static = Self::encrypt_with_ad(&k, &h, self.public_key.as_slice())?;
        mix_hash(&mut h, &encrypted_static);
        response_buffer
            .write_all(&encrypted_static)
            .map_err(|_| NoiseError::ResponseBufferTooSmall)?;

        // -> ss
        let dh_output = self.private_key.diffie_hellman(&rs);
        let k = mix_key(&mut ck, &dh_output)?;

        // -> payload
        let encrypted_payload = Self::encrypt_with_ad(&k, &h, payload.unwrap_or(&[]))?;
        mix_hash(&mut h, &encrypted_payload);
        response_buffer
            .write_all(&encrypted_payload)
            .map_err(|_| NoiseError::ResponseBufferTooSmall)?;

        Ok(InitiatorHandshakeState { h, ck, e, rs })
    }

    /// Completes the initiator side of the handshake using the responder's reply.
    ///
    /// Processes `e, ee, se` and the encrypted payload. Returns the decrypted responder
    /// payload together with the transport session. The initiator writes with the first
    /// split key and reads with the second.
    ///
    /// # Errors
    /// - `ReceivedMsgTooLarge` if `received_message` exceeds `MAX_SIZE_NOISE_MSG`.
    /// - `MsgTooShort` if it cannot contain the responder's ephemeral key.
    /// - `Hkdf` / `Decrypt` if key derivation or payload authentication fails.
    pub fn finalize_connection(
        &self,
        handshake_state: InitiatorHandshakeState,
        received_message: &[u8],
    ) -> Result<(Vec<u8>, NoiseSession), NoiseError> {
        // checks
        if received_message.len() > MAX_SIZE_NOISE_MSG {
            return Err(NoiseError::ReceivedMsgTooLarge);
        }

        // retrieve handshake state
        let InitiatorHandshakeState {
            mut h,
            mut ck,
            e,
            rs,
        } = handshake_state;

        // <- e
        let mut re = [0u8; x25519::PUBLIC_KEY_SIZE];
        let mut cursor = Cursor::new(received_message);
        cursor
            .read_exact(&mut re)
            .map_err(|_| NoiseError::MsgTooShort)?;
        mix_hash(&mut h, &re);
        let re = x25519::PublicKey::from(re);

        // <- ee (only the chaining key is kept; the derived key is superseded below)
        let dh_output = e.diffie_hellman(&re);
        mix_key(&mut ck, &dh_output)?;

        // <- se
        let dh_output = self.private_key.diffie_hellman(&re);
        let k = mix_key(&mut ck, &dh_output)?;

        // <- payload: everything after the ephemeral key. `position()` never exceeds
        // `received_message.len()`, so the cast to usize is lossless.
        let offset = cursor.position() as usize;
        let received_payload = Self::decrypt_with_ad(&k, &h, &received_message[offset..])?;

        // split
        let (k1, k2) = hkdf(&ck, None)?;
        let session = NoiseSession::new(k1, k2, rs);
        Ok((received_payload, session))
    }

    /// Processes the initiator's first handshake message as the responder.
    ///
    /// Processes `e, es, s, ss` and the encrypted payload. Returns the initiator's static
    /// public key, the state needed by `respond_to_client`, and the decrypted payload.
    ///
    /// # Errors
    /// - `ReceivedMsgTooLarge` if `received_message` exceeds `MAX_SIZE_NOISE_MSG`.
    /// - `MsgTooShort` if it cannot contain the ephemeral key and encrypted static key.
    /// - `Decrypt` if the static key or payload fails authentication.
    /// - `WrongPublicKeyReceived` if the decrypted static key is malformed.
    /// - `Hkdf` if key derivation fails.
    pub fn parse_client_init_message(
        &self,
        prologue: &[u8],
        received_message: &[u8],
    ) -> Result<
        (
            x25519::PublicKey,       // initiator's public key
            ResponderHandshakeState, // state to be used in respond_to_client
            Vec<u8>,                 // payload received
        ),
        NoiseError,
    > {
        // checks
        if received_message.len() > MAX_SIZE_NOISE_MSG {
            return Err(NoiseError::ReceivedMsgTooLarge);
        }

        // initialize: h and ck both start as the padded protocol name
        let mut h = PROTOCOL_NAME.to_vec();
        let mut ck = PROTOCOL_NAME.to_vec();
        mix_hash(&mut h, prologue);
        mix_hash(&mut h, self.public_key.as_slice());

        let mut cursor = Cursor::new(received_message);

        // <- e
        let mut re = [0u8; x25519::PUBLIC_KEY_SIZE];
        cursor
            .read_exact(&mut re)
            .map_err(|_| NoiseError::MsgTooShort)?;
        mix_hash(&mut h, &re);
        let re = x25519::PublicKey::from(re);

        // <- es
        let dh_output = self.private_key.diffie_hellman(&re);
        let k = mix_key(&mut ck, &dh_output)?;

        // <- s
        let mut encrypted_remote_static = [0u8; x25519::PUBLIC_KEY_SIZE + AES_GCM_TAGLEN];
        cursor
            .read_exact(&mut encrypted_remote_static)
            .map_err(|_| NoiseError::MsgTooShort)?;
        let rs = Self::decrypt_with_ad(&k, &h, &encrypted_remote_static)?;
        let rs = x25519::PublicKey::try_from(rs.as_slice())
            .map_err(|_| NoiseError::WrongPublicKeyReceived)?;
        mix_hash(&mut h, &encrypted_remote_static);

        // <- ss
        let dh_output = self.private_key.diffie_hellman(&rs);
        let k = mix_key(&mut ck, &dh_output)?;

        // <- payload: everything after `e` and the encrypted `s`. `position()` never
        // exceeds `received_message.len()`, so the cast to usize is lossless.
        let offset = cursor.position() as usize;
        let received_encrypted_payload = &received_message[offset..];
        let received_payload = Self::decrypt_with_ad(&k, &h, received_encrypted_payload)?;
        mix_hash(&mut h, received_encrypted_payload);

        let handshake_state = ResponderHandshakeState { h, ck, rs, re };
        Ok((rs, handshake_state, received_payload))
    }

    /// Writes the responder's handshake message and returns the transport session.
    ///
    /// Processes `e, ee, se` and the encrypted `payload`, written into the first
    /// `handshake_resp_msg_len(payload_len)` bytes of `response_buffer`. The responder
    /// writes with the second split key and reads with the first.
    ///
    /// # Errors
    /// - `PayloadTooLarge` if the message would exceed `MAX_SIZE_NOISE_MSG`.
    /// - `ResponseBufferTooSmall` if `response_buffer` cannot hold the whole message.
    /// - `Hkdf` / `Encrypt` if key derivation or encryption fails.
    pub fn respond_to_client(
        &self,
        rng: &mut (impl rand::RngCore + rand::CryptoRng),
        handshake_state: ResponderHandshakeState,
        payload: Option<&[u8]>,
        response_buffer: &mut [u8],
    ) -> Result<NoiseSession, NoiseError> {
        // checks
        let payload_len = payload.map(<[u8]>::len).unwrap_or(0);
        let buffer_size_required = handshake_resp_msg_len(payload_len);
        if buffer_size_required > MAX_SIZE_NOISE_MSG {
            return Err(NoiseError::PayloadTooLarge);
        }
        if response_buffer.len() < buffer_size_required {
            return Err(NoiseError::ResponseBufferTooSmall);
        }

        // retrieve handshake state
        let ResponderHandshakeState {
            mut h,
            mut ck,
            rs,
            re,
        } = handshake_state;

        // -> e
        let e = x25519::PrivateKey::generate(rng);
        let e_pub = e.public_key();
        mix_hash(&mut h, e_pub.as_slice());
        let mut response_buffer = Cursor::new(response_buffer);
        // `write_all` (not `write`) so a short write is reported instead of silently
        // truncating the handshake message.
        response_buffer
            .write_all(e_pub.as_slice())
            .map_err(|_| NoiseError::ResponseBufferTooSmall)?;

        // -> ee (only the chaining key is kept; the derived key is superseded below)
        let dh_output = e.diffie_hellman(&re);
        mix_key(&mut ck, &dh_output)?;

        // -> se
        let dh_output = e.diffie_hellman(&rs);
        let k = mix_key(&mut ck, &dh_output)?;

        // -> payload
        let encrypted_payload = Self::encrypt_with_ad(&k, &h, payload.unwrap_or(&[]))?;
        mix_hash(&mut h, &encrypted_payload);
        response_buffer
            .write_all(&encrypted_payload)
            .map_err(|_| NoiseError::ResponseBufferTooSmall)?;

        // split (keys swapped relative to the initiator)
        let (k1, k2) = hkdf(&ck, None)?;
        Ok(NoiseSession::new(k2, k1, rs))
    }

    /// Convenience wrapper running `parse_client_init_message` then `respond_to_client`.
    ///
    /// The initiator's static key returned by the parse step is not returned here. It
    /// remains available via `NoiseSession::get_remote_static`.
    pub fn respond_to_client_and_finalize(
        &self,
        rng: &mut (impl rand::RngCore + rand::CryptoRng),
        prologue: &[u8],
        received_message: &[u8],
        payload: Option<&[u8]>,
        response_buffer: &mut [u8],
    ) -> Result<
        (
            Vec<u8>,      // the payload the initiator sent
            NoiseSession, // The created session
        ),
        NoiseError,
    > {
        let (_, handshake_state, received_payload) =
            self.parse_client_init_message(prologue, received_message)?;
        let session = self.respond_to_client(rng, handshake_state, payload, response_buffer)?;
        Ok((received_payload, session))
    }
}
/// An established Noise transport session with one AES-256-GCM key and one nonce
/// counter per direction.
#[cfg_attr(test, derive(Clone))]
pub struct NoiseSession {
    /// Key used by `write_message_in_place`.
    write_key: Vec<u8>,
    /// Counter forming the nonce of the next outgoing message.
    write_nonce: u64,
    /// Key used by `read_message_in_place`.
    read_key: Vec<u8>,
    /// Counter forming the nonce of the next incoming message.
    read_nonce: u64,
    /// Static public key of the peer, learned during the handshake.
    remote_public_key: x25519::PublicKey,
    /// Cleared when a read fails; once false, reads and writes return `SessionClosed`.
    valid: bool,
}
impl NoiseSession {
fn new(write_key: Vec<u8>, read_key: Vec<u8>, remote_public_key: x25519::PublicKey) -> Self {
Self {
valid: true,
remote_public_key,
write_key,
write_nonce: 0,
read_key,
read_nonce: 0,
}
}
#[cfg(any(test, feature = "fuzzing"))]
pub fn new_for_testing() -> Self {
Self::new(
vec![0u8; 32],
vec![0u8; 32],
[0u8; x25519::PUBLIC_KEY_SIZE].into(),
)
}
pub fn get_remote_static(&self) -> x25519::PublicKey {
self.remote_public_key
}
pub fn write_message_in_place(&mut self, message: &mut [u8]) -> Result<Vec<u8>, NoiseError> {
if !self.valid {
return Err(NoiseError::SessionClosed);
}
if message.len() > MAX_SIZE_NOISE_MSG - AES_GCM_TAGLEN {
return Err(NoiseError::PayloadTooLarge);
}
let aead = Aes256Gcm::new(GenericArray::from_slice(&self.write_key));
let mut nonce = [0u8; 4].to_vec();
nonce.extend_from_slice(&self.write_nonce.to_be_bytes());
let nonce = GenericArray::from_slice(&nonce);
let authentication_tag = aead
.encrypt_in_place_detached(nonce, b"", message)
.map_err(|_| NoiseError::Encrypt)?;
self.write_nonce = self
.write_nonce
.checked_add(1)
.ok_or(NoiseError::NonceOverflow)?;
Ok(authentication_tag.to_vec())
}
pub fn read_message_in_place<'a>(
&mut self,
message: &'a mut [u8],
) -> Result<&'a [u8], NoiseError> {
if !self.valid {
return Err(NoiseError::SessionClosed);
}
if message.len() > MAX_SIZE_NOISE_MSG {
self.valid = false;
return Err(NoiseError::ReceivedMsgTooLarge);
}
if message.len() < AES_GCM_TAGLEN {
self.valid = false;
return Err(NoiseError::ResponseBufferTooSmall);
}
let aead = Aes256Gcm::new(GenericArray::from_slice(&self.read_key));
let mut nonce = [0u8; 4].to_vec();
nonce.extend_from_slice(&self.read_nonce.to_be_bytes());
let nonce = GenericArray::from_slice(&nonce);
let (buffer, authentication_tag) = message.split_at_mut(message.len() - AES_GCM_TAGLEN);
let authentication_tag = GenericArray::from_slice(authentication_tag);
aead.decrypt_in_place_detached(nonce, b"", buffer, authentication_tag)
.map_err(|_| {
self.valid = false;
NoiseError::Decrypt
})?;
self.read_nonce = self
.read_nonce
.checked_add(1)
.ok_or(NoiseError::NonceOverflow)?;
Ok(buffer)
}
}
impl std::fmt::Debug for NoiseSession {
    // Deliberately opaque: neither keys nor nonce counters are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("NoiseSession[...]")
    }
}