use crate::{
libp2p::{
peer_id::{PeerId, PublicKey, SignatureVerifyFailed},
read_write::{self, ReadWrite},
},
util::protobuf,
};
use alloc::{boxed::Box, collections::VecDeque, vec, vec::Vec};
use core::{cmp, fmt, iter, mem, ops};
/// Protocol name of this encryption layer.
///
/// NOTE(review): presumably the name used during protocol negotiation to select Noise; the
/// negotiation itself is not visible in this file.
pub const PROTOCOL_NAME: &str = "/noise";
/// Local key material used during a Noise handshake.
///
/// Holds an x25519 static key pair together with the libp2p handshake payload that ties the
/// static public key to a libp2p ed25519 identity. Built with [`NoiseKey::new`] or
/// [`UnsignedNoiseKey::sign`].
pub struct NoiseKey {
    /// x25519 static private key, wrapped in `Zeroizing` so that it is wiped when dropped.
    private_key: zeroize::Zeroizing<x25519_dalek::StaticSecret>,
    /// x25519 static public key matching `private_key`.
    public_key: x25519_dalek::PublicKey,
    /// Protobuf message sent (encrypted) during the handshake: field 1 contains the
    /// protobuf-encoded libp2p public key, field 2 the signature of the static public key.
    handshake_message: Vec<u8>,
    /// Ed25519 public key of the libp2p identity that signed `public_key`.
    libp2p_public_ed25519_key: [u8; 32],
}
impl NoiseKey {
    /// Builds a [`NoiseKey`] from a libp2p ed25519 private key and a Noise x25519 static
    /// private key.
    ///
    /// The Noise static public key is signed with the libp2p key, and the result is packed
    /// into the handshake payload.
    pub fn new(libp2p_ed25519_private_key: &[u8; 32], noise_static_private_key: &[u8; 32]) -> Self {
        let unsigned_key = UnsignedNoiseKey::from_private_key(noise_static_private_key);
        let payload = unsigned_key.payload_to_sign_as_vec();

        // Keep the libp2p signing key confined to this scope.
        let (verification_key, signature) = {
            let signing_key = ed25519_zebra::SigningKey::from(*libp2p_ed25519_private_key);
            (
                ed25519_zebra::VerificationKey::from(&signing_key),
                signing_key.sign(&payload),
            )
        };

        unsigned_key.sign(verification_key.into(), signature.into())
    }

    /// Returns the ed25519 public key of the libp2p identity that signed this key.
    pub fn libp2p_public_ed25519_key(&self) -> &[u8; 32] {
        let Self {
            libp2p_public_ed25519_key,
            ..
        } = self;
        libp2p_public_ed25519_key
    }
}
/// Noise static key pair whose public key has not yet been signed by a libp2p identity.
///
/// Sign the bytes returned by [`UnsignedNoiseKey::payload_to_sign`], then call
/// [`UnsignedNoiseKey::sign`] to obtain a [`NoiseKey`].
pub struct UnsignedNoiseKey {
    /// x25519 static private key. Set to `Some` by `from_private_key` and taken out by `sign`.
    private_key: Option<zeroize::Zeroizing<x25519_dalek::StaticSecret>>,
    /// x25519 static public key matching `private_key`.
    public_key: x25519_dalek::PublicKey,
}
impl UnsignedNoiseKey {
    /// Creates an unsigned key from an x25519 static private key, deriving its public key.
    pub fn from_private_key(private_key: &[u8; 32]) -> Self {
        let secret = zeroize::Zeroizing::new(x25519_dalek::StaticSecret::from(*private_key));
        UnsignedNoiseKey {
            public_key: x25519_dalek::PublicKey::from(&*secret),
            private_key: Some(secret),
        }
    }

    /// Returns, as a sequence of chunks, the bytes the libp2p identity key must sign.
    ///
    /// The bytes are the ASCII prefix `noise-libp2p-static-key:` followed by the x25519 static
    /// public key.
    pub fn payload_to_sign(&self) -> impl Iterator<Item = impl AsRef<[u8]>> {
        let prefix: &[u8] = b"noise-libp2p-static-key:";
        let static_public_key: &[u8] = self.public_key.as_bytes();
        [prefix, static_public_key].into_iter()
    }

    /// Same as [`UnsignedNoiseKey::payload_to_sign`], concatenated into a single buffer.
    pub fn payload_to_sign_as_vec(&self) -> Vec<u8> {
        let mut payload = Vec::new();
        for chunk in self.payload_to_sign() {
            payload.extend_from_slice(chunk.as_ref());
        }
        payload
    }

    /// Turns this key into a [`NoiseKey`].
    ///
    /// Takes the libp2p ed25519 public key and the signature it produced over
    /// [`UnsignedNoiseKey::payload_to_sign`]. The signature is not checked here.
    ///
    /// # Panics
    ///
    /// Panics if the private key is absent, which cannot happen for keys built with
    /// [`UnsignedNoiseKey::from_private_key`].
    pub fn sign(mut self, libp2p_public_ed25519_key: [u8; 32], signature: [u8; 64]) -> NoiseKey {
        let encoded_identity =
            PublicKey::Ed25519(libp2p_public_ed25519_key).to_protobuf_encoding();

        // Protobuf message: identity key as field 1, signature as field 2.
        let mut handshake_message =
            Vec::with_capacity(32 + encoded_identity.len() + signature.len());
        {
            let mut append = |piece: &[u8]| handshake_message.extend_from_slice(piece);
            for piece in protobuf::bytes_tag_encode(1, &encoded_identity) {
                append(piece.as_ref());
            }
            for piece in protobuf::bytes_tag_encode(2, &signature) {
                append(piece.as_ref());
            }
        }

        let private_key = self.private_key.take().unwrap();
        NoiseKey {
            private_key,
            public_key: self.public_key,
            libp2p_public_ed25519_key,
            handshake_message,
        }
    }
}
/// Configuration of a handshake, passed to [`NoiseHandshake::new`] or
/// [`HandshakeInProgress::new`].
pub struct Config<'a> {
    /// Local static key and signed libp2p payload to present to the remote.
    pub key: &'a NoiseKey,
    /// Private key of the ephemeral x25519 key pair used for this handshake.
    ///
    /// NOTE(review): presumably must be freshly and randomly generated for every handshake.
    pub ephemeral_secret_key: &'a [u8; 32],
    /// `true` if the local side sends the first handshake message.
    pub is_initiator: bool,
    /// Prologue data mixed into the handshake hash before any message is exchanged.
    ///
    /// NOTE(review): per the Noise specification both sides need the same prologue for the
    /// handshake to succeed.
    pub prologue: &'a [u8],
}
/// Noise session established by a successful handshake.
///
/// Obtained from [`NoiseHandshake::Success`]. Encryption and decryption happen through
/// [`Noise::read_write`].
pub struct Noise {
    /// `true` if the local side was the handshake initiator.
    is_initiator: bool,
    /// Cipher state used to encrypt data sent to the remote.
    out_cipher_state: CipherState,
    /// Cipher state used to decrypt data received from the remote.
    in_cipher_state: CipherState,
    /// Length of the next incoming encrypted frame.
    ///
    /// Set once its 2-byte big-endian length prefix has been read but the frame body has not.
    next_in_message_size: Option<u16>,
    /// Decrypted bytes not yet consumed by the inner stream.
    rx_buffer_decrypted: Vec<u8>,
    /// Number of decrypted bytes the inner stream last reported expecting.
    inner_stream_expected_incoming_bytes: usize,
}
impl Noise {
    /// Returns `true` if the local side initiated the handshake that produced this session.
    pub fn is_initiator(&self) -> bool {
        self.is_initiator
    }

    /// Decrypts incoming frames from `outer_read_write` and exposes the plaintext stream.
    ///
    /// Frames are decrypted until the inner stream has at least as many bytes as it last
    /// asked for, or no complete frame is available. The plaintext stream is returned as an
    /// [`InnerReadWrite`]. Data written into that object is encrypted and queued on
    /// `outer_read_write` when it is dropped.
    ///
    /// # Errors
    ///
    /// Returns an error if an incoming frame fails to decrypt, or if the outgoing nonce has
    /// overflowed.
    pub fn read_write<'a, TNow: Clone>(
        &'a mut self,
        outer_read_write: &'a mut ReadWrite<TNow>,
    ) -> Result<InnerReadWrite<'a, TNow>, CipherError> {
        loop {
            let inner_is_satisfied = !self.rx_buffer_decrypted.is_empty()
                && self.inner_stream_expected_incoming_bytes <= self.rx_buffer_decrypted.len();
            if inner_is_satisfied {
                break;
            }

            match self.next_in_message_size {
                // Length prefix already known: wait for the whole encrypted frame.
                Some(frame_len) => {
                    let Ok(Some(encrypted_frame)) =
                        outer_read_write.incoming_bytes_take(usize::from(frame_len))
                    else {
                        break;
                    };
                    self.next_in_message_size = None;
                    self.in_cipher_state.read_chachapoly_message_to_vec_append(
                        &[],
                        &encrypted_frame,
                        &mut self.rx_buffer_decrypted,
                    )?;
                }
                // Otherwise read the 2-byte big-endian length prefix of the next frame.
                None => {
                    let Ok(Some(length_prefix)) =
                        outer_read_write.incoming_bytes_take_array::<2>()
                    else {
                        break;
                    };
                    self.next_in_message_size = Some(u16::from_be_bytes(length_prefix));
                }
            }
        }

        if self.out_cipher_state.nonce_has_overflowed {
            return Err(CipherError::NonceOverflow);
        }

        // The inner expectation is only forwarded if the outer stream reports an expectation
        // or still has buffered incoming data; otherwise the inner stream sees `None`.
        let outer_has_input = outer_read_write.expected_incoming_bytes.is_some()
            || !outer_read_write.incoming_buffer.is_empty();
        let expected_incoming_bytes =
            outer_has_input.then_some(self.inner_stream_expected_incoming_bytes);

        // Every outgoing frame costs a 2-byte length prefix plus a 16-byte tag, and a frame
        // (tag included) must fit in the `u16` length prefix.
        let write_bytes_queueable = outer_read_write
            .write_bytes_queueable
            .map(|available| cmp::min(available.saturating_sub(16 + 2), 65535 - 16));

        let inner_read_write = ReadWrite {
            now: outer_read_write.now.clone(),
            incoming_buffer: mem::take(&mut self.rx_buffer_decrypted),
            read_bytes: 0,
            expected_incoming_bytes,
            write_buffers: Vec::new(),
            write_bytes_queued: 0,
            write_bytes_queueable,
            wake_up_after: outer_read_write.wake_up_after.clone(),
        };

        Ok(InnerReadWrite {
            noise: self,
            outer_read_write,
            inner_read_write,
        })
    }
}
impl fmt::Debug for Noise {
    /// Prints only the type name; keys and buffers are not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Noise")
    }
}
/// Plaintext view of a [`Noise`] session, returned by [`Noise::read_write`].
///
/// Dereferences to a [`ReadWrite`] over decrypted data. When dropped, it does two things:
/// - unread decrypted data is handed back to the [`Noise`];
/// - written data is encrypted into one frame queued on the outer [`ReadWrite`].
pub struct InnerReadWrite<'a, TNow: Clone> {
    /// Session that produced this object.
    noise: &'a mut Noise,
    /// Stream carrying the encrypted frames.
    outer_read_write: &'a mut ReadWrite<TNow>,
    /// Plaintext stream exposed through `Deref`/`DerefMut`.
    inner_read_write: ReadWrite<TNow>,
}
impl<'a, TNow: Clone> ops::Deref for InnerReadWrite<'a, TNow> {
type Target = ReadWrite<TNow>;
fn deref(&self) -> &Self::Target {
&self.inner_read_write
}
}
impl<'a, TNow: Clone> ops::DerefMut for InnerReadWrite<'a, TNow> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner_read_write
}
}
impl<'a, TNow: Clone> Drop for InnerReadWrite<'a, TNow> {
fn drop(&mut self) {
self.outer_read_write.wake_up_after = self.inner_read_write.wake_up_after.clone();
self.noise.rx_buffer_decrypted = mem::take(&mut self.inner_read_write.incoming_buffer);
self.noise.inner_stream_expected_incoming_bytes =
self.inner_read_write.expected_incoming_bytes.unwrap_or(0);
if self.inner_read_write.read_bytes != 0 {
self.outer_read_write.wake_up_asap();
}
if self
.inner_read_write
.write_buffers
.iter()
.any(|b| !b.is_empty())
{
self.outer_read_write
.write_buffers
.reserve(2 + self.inner_read_write.write_buffers.len() * 2);
let message_length_prefix_index = self.outer_read_write.write_buffers.len();
self.outer_read_write.write_buffers.push(Vec::new());
let mut total_size = 0;
for encrypted_buffer in self
.noise
.out_cipher_state
.write_chachapoly_message(&[], self.inner_read_write.write_buffers.drain(..))
.unwrap_or_else(|_| unreachable!())
{
total_size += encrypted_buffer.len();
self.outer_read_write.write_buffers.push(encrypted_buffer);
}
let message_length_prefix = u16::try_from(total_size).unwrap().to_be_bytes().to_vec();
self.outer_read_write.write_buffers[message_length_prefix_index] =
message_length_prefix;
self.outer_read_write.write_bytes_queued += total_size + 2;
*self
.outer_read_write
.write_bytes_queueable
.as_mut()
.unwrap() -= total_size + 2;
}
}
}
/// State of a Noise handshake.
#[derive(Debug)]
pub enum NoiseHandshake {
    /// Handshake not finished; call [`HandshakeInProgress::read_write`] to advance it.
    InProgress(HandshakeInProgress),
    /// Handshake finished successfully.
    Success {
        /// Session with which to encrypt and decrypt the rest of the connection.
        cipher: Noise,
        /// Identity of the remote.
        ///
        /// Derived from the libp2p public key the remote sent. Its signature over the remote
        /// Noise static key has been verified.
        remote_peer_id: PeerId,
    },
}
/// Noise handshake that has not finished yet.
///
/// NOTE(review): the state is boxed, presumably to keep [`NoiseHandshake`] small.
pub struct HandshakeInProgress(Box<HandshakeInProgressInner>);
/// State of an in-progress `Noise_XX_25519_ChaChaPoly_SHA256` handshake.
struct HandshakeInProgressInner {
    /// `true` if the local side sends the first handshake message.
    is_initiator: bool,
    /// Bytes of the current outgoing message not yet written to the stream.
    ///
    /// Includes the message's 2-byte big-endian length prefix.
    pending_out_data: VecDeque<u8>,
    /// Length of the next incoming message, once its length prefix has been read.
    next_in_message_size: Option<u16>,
    /// Number of handshake messages queued for sending or fully received.
    ///
    /// The handshake is complete when this reaches 3.
    num_buffered_or_transmitted_messages: u8,
    /// Noise `CipherState` used for the encrypted parts of the handshake messages.
    cipher_state: CipherState,
    /// Noise chaining key (`ck`).
    chaining_key: zeroize::Zeroizing<[u8; 32]>,
    /// Noise handshake hash (`h`); used as associated data for handshake encryption.
    hash: zeroize::Zeroizing<[u8; 32]>,
    /// Local ephemeral private key (`e`).
    local_ephemeral_private_key: zeroize::Zeroizing<x25519_dalek::StaticSecret>,
    /// Local static private key (`s`).
    local_static_private_key: zeroize::Zeroizing<x25519_dalek::StaticSecret>,
    /// Local static public key; sent encrypted to the remote.
    local_static_public_key: x25519_dalek::PublicKey,
    /// Remote ephemeral public key (`re`); all zeroes until received.
    remote_ephemeral_public_key: x25519_dalek::PublicKey,
    /// Remote static public key (`rs`); all zeroes until received.
    remote_static_public_key: x25519_dalek::PublicKey,
    /// libp2p public key of the remote, set once its signature over `rs` has been verified.
    remote_public_key: Option<PublicKey>,
    /// Local protobuf-encoded libp2p payload; sent encrypted during the handshake.
    libp2p_handshake_message: Vec<u8>,
}
impl NoiseHandshake {
    /// Starts a new handshake in the [`NoiseHandshake::InProgress`] state.
    pub fn new(config: Config) -> Self {
        let in_progress = HandshakeInProgress::new(config);
        Self::InProgress(in_progress)
    }
}
impl HandshakeInProgress {
pub fn new(config: Config) -> Self {
let local_ephemeral_private_key = zeroize::Zeroizing::new(
x25519_dalek::StaticSecret::from(*config.ephemeral_secret_key),
);
let mut hash = zeroize::Zeroizing::new([0u8; 32]);
{
const PROTOCOL_NAME: &[u8] = b"Noise_XX_25519_ChaChaPoly_SHA256";
if PROTOCOL_NAME.len() <= hash.len() {
hash[..PROTOCOL_NAME.len()].copy_from_slice(PROTOCOL_NAME);
hash[PROTOCOL_NAME.len()..].fill(0);
} else {
let mut hasher = <sha2::Sha256 as sha2::Digest>::new();
sha2::Digest::update(&mut hasher, PROTOCOL_NAME);
sha2::Digest::finalize_into(
hasher,
sha2::digest::generic_array::GenericArray::from_mut_slice(&mut *hash),
);
}
}
let chaining_key = hash.clone();
mix_hash(&mut hash, config.prologue);
HandshakeInProgress(Box::new(HandshakeInProgressInner {
cipher_state: CipherState {
key: zeroize::Zeroizing::new([0; 32]),
nonce: 0,
nonce_has_overflowed: false,
},
chaining_key,
hash,
local_ephemeral_private_key,
local_static_private_key: config.key.private_key.clone(),
local_static_public_key: config.key.public_key,
remote_ephemeral_public_key: x25519_dalek::PublicKey::from([0; 32]),
remote_static_public_key: x25519_dalek::PublicKey::from([0; 32]),
remote_public_key: None,
is_initiator: config.is_initiator,
pending_out_data: VecDeque::with_capacity(usize::from(u16::MAX) + 2),
next_in_message_size: None,
num_buffered_or_transmitted_messages: 0,
libp2p_handshake_message: config.key.handshake_message.clone(),
}))
}
pub fn read_write<TNow>(
mut self,
read_write: &mut ReadWrite<TNow>,
) -> Result<NoiseHandshake, HandshakeError> {
loop {
read_write.write_from_vec_deque(&mut self.0.pending_out_data);
if !self.0.pending_out_data.is_empty() {
if read_write.write_bytes_queueable.is_none() {
return Err(HandshakeError::WriteClosed);
}
return Ok(NoiseHandshake::InProgress(self));
}
if self.0.num_buffered_or_transmitted_messages == 3 {
debug_assert!(self.0.pending_out_data.is_empty());
debug_assert!(self.0.next_in_message_size.is_none());
let HkdfOutput {
output1: init_to_resp,
output2: resp_to_init,
} = hkdf(&self.0.chaining_key, &[]);
let (out_key, in_key) = match self.0.is_initiator {
true => (init_to_resp, resp_to_init),
false => (resp_to_init, init_to_resp),
};
return Ok(NoiseHandshake::Success {
cipher: Noise {
is_initiator: self.0.is_initiator,
out_cipher_state: CipherState {
key: out_key,
nonce: 0,
nonce_has_overflowed: false,
},
in_cipher_state: CipherState {
key: in_key,
nonce: 0,
nonce_has_overflowed: false,
},
rx_buffer_decrypted: Vec::with_capacity(65535 - 16),
next_in_message_size: None,
inner_stream_expected_incoming_bytes: 0,
},
remote_peer_id: {
self.0
.remote_public_key
.take()
.unwrap_or_else(|| unreachable!())
.into_peer_id()
},
});
}
match (
self.0.num_buffered_or_transmitted_messages,
self.0.is_initiator,
) {
(0, true) => {
let local_ephemeral_public_key =
x25519_dalek::PublicKey::from(&*self.0.local_ephemeral_private_key);
self.0
.pending_out_data
.extend(local_ephemeral_public_key.as_bytes());
mix_hash(&mut self.0.hash, local_ephemeral_public_key.as_bytes());
mix_hash(&mut self.0.hash, &[]);
let len = u16::try_from(self.0.pending_out_data.len())
.unwrap()
.to_be_bytes();
self.0.pending_out_data.push_front(len[1]);
self.0.pending_out_data.push_front(len[0]);
self.0.num_buffered_or_transmitted_messages += 1;
continue;
}
(1, false) => {
let local_ephemeral_public_key =
x25519_dalek::PublicKey::from(&*self.0.local_ephemeral_private_key);
self.0
.pending_out_data
.extend(local_ephemeral_public_key.as_bytes());
mix_hash(&mut self.0.hash, local_ephemeral_public_key.as_bytes());
let HkdfOutput {
output1: chaining_key_update,
output2: key_update,
} = hkdf(
&self.0.chaining_key,
self.0
.local_ephemeral_private_key
.diffie_hellman(&self.0.remote_ephemeral_public_key)
.as_bytes(),
);
self.0.chaining_key = chaining_key_update;
self.0.cipher_state.key = key_update;
self.0.cipher_state.nonce = 0;
let encrypted_static_public_key = self
.0
.cipher_state
.write_chachapoly_message_to_vec(
&*self.0.hash,
self.0.local_static_public_key.as_bytes(),
)
.unwrap();
self.0
.pending_out_data
.extend(encrypted_static_public_key.iter().copied());
mix_hash(&mut self.0.hash, &encrypted_static_public_key);
let HkdfOutput {
output1: chaining_key_update,
output2: key_update,
} = hkdf(
&self.0.chaining_key,
self.0
.local_static_private_key
.diffie_hellman(&self.0.remote_ephemeral_public_key)
.as_bytes(),
);
self.0.chaining_key = chaining_key_update;
self.0.cipher_state.key = key_update;
self.0.cipher_state.nonce = 0;
let encrypted_libp2p_handshake = self
.0
.cipher_state
.write_chachapoly_message_to_vec(
&*self.0.hash,
&self.0.libp2p_handshake_message,
)
.unwrap();
self.0
.pending_out_data
.extend(encrypted_libp2p_handshake.iter().copied());
mix_hash(&mut self.0.hash, &encrypted_libp2p_handshake);
let len = u16::try_from(self.0.pending_out_data.len())
.unwrap()
.to_be_bytes();
self.0.pending_out_data.push_front(len[1]);
self.0.pending_out_data.push_front(len[0]);
self.0.num_buffered_or_transmitted_messages += 1;
continue;
}
(2, true) => {
let encrypted_static_public_key = self
.0
.cipher_state
.write_chachapoly_message_to_vec(
&*self.0.hash,
self.0.local_static_public_key.as_bytes(),
)
.unwrap();
self.0
.pending_out_data
.extend(encrypted_static_public_key.iter().copied());
mix_hash(&mut self.0.hash, &encrypted_static_public_key);
let HkdfOutput {
output1: chaining_key_update,
output2: key_update,
} = hkdf(
&self.0.chaining_key,
self.0
.local_static_private_key
.diffie_hellman(&self.0.remote_ephemeral_public_key)
.as_bytes(),
);
self.0.chaining_key = chaining_key_update;
self.0.cipher_state.key = key_update;
self.0.cipher_state.nonce = 0;
let encrypted_libp2p_handshake = self
.0
.cipher_state
.write_chachapoly_message_to_vec(
&*self.0.hash,
&self.0.libp2p_handshake_message,
)
.unwrap();
self.0
.pending_out_data
.extend(encrypted_libp2p_handshake.iter().copied());
mix_hash(&mut self.0.hash, &encrypted_libp2p_handshake);
let len = u16::try_from(self.0.pending_out_data.len())
.unwrap()
.to_be_bytes();
self.0.pending_out_data.push_front(len[1]);
self.0.pending_out_data.push_front(len[0]);
self.0.num_buffered_or_transmitted_messages += 1;
continue;
}
_ => {}
}
let next_in_message_size =
if let Some(next_in_message_size) = self.0.next_in_message_size {
next_in_message_size
} else {
match read_write.incoming_bytes_take(2) {
Ok(Some(size_buffer)) => *self.0.next_in_message_size.insert(
u16::from_be_bytes(<[u8; 2]>::try_from(&size_buffer[..2]).unwrap()),
),
Ok(None) => {
return Ok(NoiseHandshake::InProgress(self));
}
Err(read_write::IncomingBytesTakeError::ReadClosed) => {
return Err(HandshakeError::ReadClosed);
}
}
};
let available_message =
match read_write.incoming_bytes_take(usize::from(next_in_message_size)) {
Ok(Some(available_message)) => {
self.0.next_in_message_size = None;
available_message
}
Ok(None) => {
return Ok(NoiseHandshake::InProgress(self));
}
Err(read_write::IncomingBytesTakeError::ReadClosed) => {
return Err(HandshakeError::ReadClosed);
}
};
match (
self.0.num_buffered_or_transmitted_messages,
self.0.is_initiator,
) {
(0, false) => {
self.0.remote_ephemeral_public_key = x25519_dalek::PublicKey::from(*{
let mut parser =
nom::combinator::all_consuming::<_, (&[u8], nom::error::ErrorKind), _>(
nom::combinator::map(nom::bytes::streaming::take(32u32), |k| {
<&[u8; 32]>::try_from(k).unwrap()
}),
);
match nom::Parser::parse(&mut parser, &available_message) {
Ok((_, out)) => out,
Err(_) => {
return Err(HandshakeError::PayloadDecode(PayloadDecodeError));
}
}
});
mix_hash(
&mut self.0.hash,
self.0.remote_ephemeral_public_key.as_bytes(),
);
mix_hash(&mut self.0.hash, &[]);
self.0.num_buffered_or_transmitted_messages += 1;
continue;
}
(1, true) => {
let (
remote_ephemeral_public_key,
remote_static_public_key_encrypted,
libp2p_handshake_encrypted,
) = {
let mut parser =
nom::combinator::all_consuming::<_, (&[u8], nom::error::ErrorKind), _>(
(
nom::combinator::map(nom::bytes::streaming::take(32u32), |k| {
<&[u8; 32]>::try_from(k).unwrap()
}),
nom::combinator::map(nom::bytes::streaming::take(48u32), |k| {
<&[u8; 48]>::try_from(k).unwrap()
}),
nom::combinator::rest,
),
);
match nom::Parser::parse(&mut parser, &available_message) {
Ok((_, out)) => out,
Err(_) => {
return Err(HandshakeError::PayloadDecode(PayloadDecodeError));
}
}
};
self.0.remote_ephemeral_public_key =
x25519_dalek::PublicKey::from(*remote_ephemeral_public_key);
mix_hash(
&mut self.0.hash,
self.0.remote_ephemeral_public_key.as_bytes(),
);
let HkdfOutput {
output1: chaining_key_update,
output2: key_update,
} = hkdf(
&self.0.chaining_key,
self.0
.local_ephemeral_private_key
.diffie_hellman(&self.0.remote_ephemeral_public_key)
.as_bytes(),
);
self.0.chaining_key = chaining_key_update;
self.0.cipher_state.key = key_update;
self.0.cipher_state.nonce = 0;
self.0.remote_static_public_key = x25519_dalek::PublicKey::from(
self.0
.cipher_state
.read_chachapoly_message_to_array(
&*self.0.hash,
remote_static_public_key_encrypted,
)
.map_err(HandshakeError::Cipher)?,
);
mix_hash(&mut self.0.hash, remote_static_public_key_encrypted);
let HkdfOutput {
output1: chaining_key_update,
output2: key_update,
} = hkdf(
&self.0.chaining_key,
self.0
.local_ephemeral_private_key
.diffie_hellman(&self.0.remote_static_public_key)
.as_bytes(),
);
self.0.chaining_key = chaining_key_update;
self.0.cipher_state.key = key_update;
self.0.cipher_state.nonce = 0;
self.0.remote_public_key = Some({
let libp2p_handshake_decrypted = self
.0
.cipher_state
.read_chachapoly_message_to_vec(
&*self.0.hash,
libp2p_handshake_encrypted,
)
.map_err(HandshakeError::Cipher)?;
let (libp2p_key, libp2p_signature) = {
let mut parser =
nom::combinator::all_consuming::<
_,
(&[u8], nom::error::ErrorKind),
_,
>(protobuf::message_decode! {
#[required] key = 1 => protobuf::bytes_tag_decode,
#[required] sig = 2 => protobuf::bytes_tag_decode,
});
match nom::Parser::parse(&mut parser, &libp2p_handshake_decrypted) {
Ok((_, out)) => (out.key, out.sig),
Err(_) => {
return Err(HandshakeError::PayloadDecode(PayloadDecodeError));
}
}
};
let remote_public_key = PublicKey::from_protobuf_encoding(libp2p_key)
.map_err(|_| HandshakeError::InvalidKey)?;
remote_public_key
.verify(
&[
&b"noise-libp2p-static-key:"[..],
&self.0.remote_static_public_key.as_bytes()[..],
]
.concat(),
libp2p_signature,
)
.map_err(HandshakeError::SignatureVerificationFailed)?;
remote_public_key
});
mix_hash(&mut self.0.hash, libp2p_handshake_encrypted);
self.0.num_buffered_or_transmitted_messages += 1;
continue;
}
(2, false) => {
let (remote_static_public_key_encrypted, libp2p_handshake_encrypted) = {
let mut parser =
nom::combinator::all_consuming::<_, (&[u8], nom::error::ErrorKind), _>(
(
nom::combinator::map(nom::bytes::streaming::take(48u32), |k| {
<&[u8; 48]>::try_from(k).unwrap()
}),
nom::combinator::rest,
),
);
match nom::Parser::parse(&mut parser, &available_message) {
Ok((_, out)) => out,
Err(_) => {
return Err(HandshakeError::PayloadDecode(PayloadDecodeError));
}
}
};
self.0.remote_static_public_key = x25519_dalek::PublicKey::from(
self.0
.cipher_state
.read_chachapoly_message_to_array(
&*self.0.hash,
remote_static_public_key_encrypted,
)
.map_err(HandshakeError::Cipher)?,
);
mix_hash(&mut self.0.hash, remote_static_public_key_encrypted);
let HkdfOutput {
output1: chaining_key_update,
output2: key_update,
} = hkdf(
&self.0.chaining_key,
self.0
.local_ephemeral_private_key
.clone()
.diffie_hellman(&self.0.remote_static_public_key)
.as_bytes(),
);
self.0.chaining_key = chaining_key_update;
self.0.cipher_state.key = key_update;
self.0.cipher_state.nonce = 0;
self.0.remote_public_key = Some({
let libp2p_handshake_decrypted = self
.0
.cipher_state
.read_chachapoly_message_to_vec(
&*self.0.hash,
libp2p_handshake_encrypted,
)
.map_err(HandshakeError::Cipher)?;
let (libp2p_key, libp2p_signature) = {
let mut parser =
nom::combinator::all_consuming::<
_,
(&[u8], nom::error::ErrorKind),
_,
>(protobuf::message_decode! {
#[required] key = 1 => protobuf::bytes_tag_decode,
#[required] sig = 2 => protobuf::bytes_tag_decode,
});
match nom::Parser::parse(&mut parser, &libp2p_handshake_decrypted) {
Ok((_, out)) => (out.key, out.sig),
Err(_) => {
return Err(HandshakeError::PayloadDecode(PayloadDecodeError));
}
}
};
let remote_public_key = PublicKey::from_protobuf_encoding(libp2p_key)
.map_err(|_| HandshakeError::InvalidKey)?;
remote_public_key
.verify(
&[
&b"noise-libp2p-static-key:"[..],
&self.0.remote_static_public_key.as_bytes()[..],
]
.concat(),
libp2p_signature,
)
.map_err(HandshakeError::SignatureVerificationFailed)?;
remote_public_key
});
mix_hash(&mut self.0.hash, libp2p_handshake_encrypted);
self.0.num_buffered_or_transmitted_messages += 1;
continue;
}
_ => {
unreachable!()
}
}
}
}
}
impl fmt::Debug for HandshakeInProgress {
    /// Prints only the type name; none of the handshake state is printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("HandshakeInProgress")
    }
}
/// Error that can happen during the handshake.
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub enum HandshakeError {
    /// Reading side of the stream closed while a handshake message was expected.
    ReadClosed,
    /// Writing side of the stream closed while handshake data was waiting to be sent.
    WriteClosed,
    /// Failed to decrypt part of a handshake message.
    #[display("Cipher error: {_0}")]
    Cipher(CipherError),
    /// Handshake message, or the libp2p payload it carries, has an invalid format.
    #[display("Failed to decode payload as the libp2p-extension-to-noise payload: {_0}")]
    PayloadDecode(PayloadDecodeError),
    /// libp2p public key sent by the remote could not be decoded.
    InvalidKey,
    /// Signature of the remote Noise static key by the remote libp2p key is invalid.
    #[display("Signature of the noise public key by the libp2p key failed.")]
    SignatureVerificationFailed(SignatureVerifyFailed),
}
/// Error that can happen when encrypting data.
#[derive(Debug, derive_more::Display, derive_more::Error)]
#[display("Error while encrypting the Noise payload")]
pub enum EncryptError {
    /// Nonce has overflowed; the cipher state can no longer encrypt.
    NonceOverflow,
}
/// Error that can happen when decrypting data.
///
/// Also returned by [`Noise::read_write`] when the outgoing nonce has overflowed.
#[derive(Debug, derive_more::Display, derive_more::Error)]
#[display("Error while decrypting the Noise payload")]
pub enum CipherError {
    /// Message is shorter than the 16-byte authentication tag.
    MissingHmac,
    /// Authentication tag doesn't match the message.
    HmacInvalid,
    /// Nonce has overflowed; the cipher state can no longer be used.
    NonceOverflow,
}
/// A handshake message, or the libp2p payload it carries, could not be decoded.
#[derive(Debug, derive_more::Display, derive_more::Error)]
pub struct PayloadDecodeError;
/// Noise `CipherState` for ChaCha20-Poly1305.
struct CipherState {
    /// 32-byte ChaCha20 key.
    key: zeroize::Zeroizing<[u8; 32]>,
    /// Nonce of the next message.
    ///
    /// Incremented whenever a message is encrypted or successfully decrypted.
    nonce: u64,
    /// Set once incrementing `nonce` wrapped around; encryption and decryption then fail.
    nonce_has_overflowed: bool,
}
impl CipherState {
    /// Encrypts the concatenation of `decrypted_buffers` as one ChaCha20-Poly1305 message
    /// (RFC 8439 AEAD) using the current nonce.
    ///
    /// The returned iterator yields the ciphertext split across several buffers, followed by a
    /// final buffer holding the 16-byte Poly1305 tag. Input buffers are encrypted in place
    /// where possible.
    ///
    /// Bytes that do not fill a whole 64-byte block are carried over and merged with the next
    /// buffer. This way, every chunk given to Poly1305 except the last is a multiple of
    /// 16 bytes, which matters because `update_padded` zero-pads every call.
    ///
    /// The nonce is incremented immediately, even if the iterator is never consumed.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptError::NonceOverflow`] if the nonce has already overflowed.
    fn write_chachapoly_message(
        &mut self,
        associated_data: &[u8],
        decrypted_buffers: impl Iterator<Item = Vec<u8>>,
    ) -> Result<impl Iterator<Item = Vec<u8>>, EncryptError> {
        if self.nonce_has_overflowed {
            return Err(EncryptError::NonceOverflow);
        }

        let (mut cipher, mac) = self.prepare(associated_data);
        let associated_data_len = associated_data.len();

        // Incremented up front so that the returned iterator does not need to touch `self`.
        (self.nonce, self.nonce_has_overflowed) = self.nonce.overflowing_add(1);

        let mut decrypted_buffers = decrypted_buffers.peekable();
        let mut total_decrypted_data = 0;
        // Plaintext bytes that don't fill a whole 64-byte block yet. Always shorter than
        // 64 bytes between iterations.
        let mut overlapping_data = Vec::new();
        // Becomes `None` once the tag has been emitted, which ends the iteration.
        let mut mac = Some(mac);

        Ok(iter::from_fn(move || {
            loop {
                debug_assert!(overlapping_data.len() < 64);

                let Some(mac_deref) = mac.as_mut() else {
                    return None;
                };

                if !overlapping_data.is_empty() {
                    if let Some(next_buffer) = decrypted_buffers.peek_mut() {
                        let missing_data_for_full_frame = 64 - overlapping_data.len();
                        if next_buffer.len() >= missing_data_for_full_frame {
                            // Complete the pending block with the start of the next buffer,
                            // then shift the rest of that buffer to its front.
                            overlapping_data
                                .extend_from_slice(&next_buffer[..missing_data_for_full_frame]);
                            next_buffer.copy_within(missing_data_for_full_frame.., 0);
                            next_buffer.truncate(next_buffer.len() - missing_data_for_full_frame);
                            chacha20::cipher::StreamCipher::apply_keystream(
                                &mut cipher,
                                &mut overlapping_data,
                            );
                            poly1305::universal_hash::UniversalHash::update_padded(
                                mac_deref,
                                &overlapping_data,
                            );
                            debug_assert_eq!(overlapping_data.len(), 64);
                            total_decrypted_data += 64;
                            return Some(mem::take(&mut overlapping_data));
                        } else {
                            // Next buffer too short to complete the block: absorb it entirely.
                            overlapping_data.extend_from_slice(next_buffer);
                            let _ = decrypted_buffers.next();
                        }
                    } else {
                        // No more input: encrypt the trailing partial block.
                        chacha20::cipher::StreamCipher::apply_keystream(
                            &mut cipher,
                            &mut overlapping_data,
                        );
                        poly1305::universal_hash::UniversalHash::update_padded(
                            mac_deref,
                            &overlapping_data,
                        );
                        total_decrypted_data += overlapping_data.len();
                        return Some(mem::take(&mut overlapping_data));
                    }
                } else if let Some(mut buffer) = decrypted_buffers.next() {
                    // Encrypt the whole 64-byte blocks in place and keep the remainder aside.
                    let encryptable_in_place = 64 * (buffer.len() / 64);
                    chacha20::cipher::StreamCipher::apply_keystream(
                        &mut cipher,
                        &mut buffer[..encryptable_in_place],
                    );
                    poly1305::universal_hash::UniversalHash::update_padded(
                        mac_deref,
                        &buffer[..encryptable_in_place],
                    );
                    if encryptable_in_place != buffer.len() {
                        overlapping_data.reserve(64);
                        overlapping_data.extend_from_slice(&buffer[encryptable_in_place..]);
                        buffer.truncate(encryptable_in_place);
                    }
                    total_decrypted_data += encryptable_in_place;
                    return Some(buffer);
                } else {
                    // All data encrypted: feed the lengths block and emit the tag.
                    let mut block =
                        poly1305::universal_hash::generic_array::GenericArray::default();
                    block[..8].copy_from_slice(
                        &u64::try_from(associated_data_len).unwrap().to_le_bytes(),
                    );
                    block[8..].copy_from_slice(
                        &u64::try_from(total_decrypted_data).unwrap().to_le_bytes(),
                    );
                    poly1305::universal_hash::UniversalHash::update(mac_deref, &[block]);
                    let mac_bytes =
                        poly1305::universal_hash::UniversalHash::finalize(mac.take().unwrap())
                            .to_vec();
                    return Some(mac_bytes);
                }
            }
        }))
    }

    /// Same as [`CipherState::write_chachapoly_message`] for a single buffer, returning the
    /// ciphertext and tag concatenated.
    fn write_chachapoly_message_to_vec(
        &mut self,
        associated_data: &[u8],
        data: &[u8],
    ) -> Result<Vec<u8>, EncryptError> {
        Ok(self
            .write_chachapoly_message(associated_data, iter::once(data.to_vec()))?
            .fold(Vec::new(), |mut a, b| {
                if a.is_empty() {
                    b
                } else {
                    a.extend_from_slice(&b);
                    a
                }
            }))
    }

    /// Decrypts a 48-byte message (32 bytes of ciphertext plus a 16-byte tag) into an array.
    fn read_chachapoly_message_to_array(
        &mut self,
        associated_data: &[u8],
        message_data: &[u8; 48],
    ) -> Result<[u8; 32], CipherError> {
        let mut out = [0; 32];
        self.read_chachapoly_message_to_slice(associated_data, message_data, &mut out)?;
        Ok(out)
    }

    /// Decrypts `message_data` (ciphertext followed by a 16-byte tag) into a new `Vec`.
    fn read_chachapoly_message_to_vec(
        &mut self,
        associated_data: &[u8],
        message_data: &[u8],
    ) -> Result<Vec<u8>, CipherError> {
        let mut destination = vec![0; message_data.len().saturating_sub(16)];
        self.read_chachapoly_message_to_slice(associated_data, message_data, &mut destination)?;
        Ok(destination)
    }

    /// Decrypts `message_data` and appends the plaintext to `out`.
    ///
    /// On error, `out` is left with its original length.
    fn read_chachapoly_message_to_vec_append(
        &mut self,
        associated_data: &[u8],
        message_data: &[u8],
        out: &mut Vec<u8>,
    ) -> Result<(), CipherError> {
        let len_before = out.len();
        out.resize(len_before + message_data.len().saturating_sub(16), 0);
        let result = self.read_chachapoly_message_to_slice(
            associated_data,
            message_data,
            &mut out[len_before..],
        );
        if result.is_err() {
            out.truncate(len_before);
        }
        result
    }

    /// Verifies and decrypts `message_data` (ciphertext followed by a 16-byte Poly1305 tag)
    /// into `destination`.
    ///
    /// `destination` must be exactly as long as the ciphertext. The nonce is incremented only
    /// on success.
    ///
    /// # Errors
    ///
    /// - [`CipherError::NonceOverflow`]: the nonce has overflowed.
    /// - [`CipherError::MissingHmac`]: the message is shorter than 16 bytes.
    /// - [`CipherError::HmacInvalid`]: the tag doesn't match.
    fn read_chachapoly_message_to_slice(
        &mut self,
        associated_data: &[u8],
        message_data: &[u8],
        destination: &mut [u8],
    ) -> Result<(), CipherError> {
        if self.nonce_has_overflowed {
            return Err(CipherError::NonceOverflow);
        }

        // Must happen before any `len() - 16`, including in the debug assertion below. The
        // remote can send frames shorter than 16 bytes, and the subtraction would otherwise
        // panic on overflow in debug builds instead of reporting an error.
        if message_data.len() < 16 {
            return Err(CipherError::MissingHmac);
        }

        let (ciphertext, obtained_mac_bytes) = message_data.split_at(message_data.len() - 16);
        debug_assert_eq!(destination.len(), ciphertext.len());

        let (mut cipher, mut mac) = self.prepare(associated_data);
        poly1305::universal_hash::UniversalHash::update_padded(&mut mac, ciphertext);
        let mut block = poly1305::universal_hash::generic_array::GenericArray::default();
        block[..8].copy_from_slice(&u64::try_from(associated_data.len()).unwrap().to_le_bytes());
        block[8..].copy_from_slice(&u64::try_from(ciphertext.len()).unwrap().to_le_bytes());
        poly1305::universal_hash::UniversalHash::update(&mut mac, &[block]);

        // Authenticate before decrypting anything.
        if poly1305::universal_hash::UniversalHash::verify(
            mac,
            poly1305::universal_hash::generic_array::GenericArray::from_slice(obtained_mac_bytes),
        )
        .is_err()
        {
            return Err(CipherError::HmacInvalid);
        }

        chacha20::cipher::StreamCipher::apply_keystream_b2b(&mut cipher, ciphertext, destination)
            .unwrap_or_else(|_| unreachable!());
        (self.nonce, self.nonce_has_overflowed) = self.nonce.overflowing_add(1);
        Ok(())
    }

    /// Builds the ChaCha20 cipher and Poly1305 MAC for the current nonce.
    ///
    /// The 96-bit nonce is 4 zero bytes followed by the little-endian 64-bit counter. The MAC
    /// is keyed with the first 32 bytes of keystream block 0. The returned cipher is then
    /// positioned at block 1 (byte 64), and `associated_data` has already been fed,
    /// zero-padded, to the MAC.
    fn prepare(&self, associated_data: &[u8]) -> (chacha20::ChaCha20, poly1305::Poly1305) {
        let mut cipher = {
            let nonce = {
                let mut out = [0; 12];
                out[4..].copy_from_slice(&self.nonce.to_le_bytes());
                out
            };
            <chacha20::ChaCha20 as chacha20::cipher::KeyIvInit>::new(
                chacha20::cipher::generic_array::GenericArray::from_slice(&self.key[..]),
                chacha20::cipher::generic_array::GenericArray::from_slice(&nonce[..]),
            )
        };
        let mut mac = {
            let mut mac_key = zeroize::Zeroizing::new([0u8; 32]);
            chacha20::cipher::StreamCipher::apply_keystream(&mut cipher, &mut *mac_key);
            chacha20::cipher::StreamCipherSeek::seek(&mut cipher, 64);
            <poly1305::Poly1305 as poly1305::universal_hash::KeyInit>::new(
                poly1305::universal_hash::generic_array::GenericArray::from_slice(&*mac_key),
            )
        };
        poly1305::universal_hash::UniversalHash::update_padded(&mut mac, associated_data);
        (cipher, mac)
    }
}
/// Noise `MixHash(data)`: replaces `hash` in place with `SHA-256(hash || data)`.
fn mix_hash(hash: &mut [u8; 32], data: &[u8]) {
    let mut sha256 = <sha2::Sha256 as sha2::Digest>::new();
    for chunk in [&hash[..], data] {
        sha2::Digest::update(&mut sha256, chunk);
    }
    let output = sha2::digest::generic_array::GenericArray::from_mut_slice(&mut hash[..]);
    sha2::Digest::finalize_into(sha256, output);
}
/// Noise `HKDF(chaining_key, input_key_material)` with two outputs, built on HMAC-SHA256.
///
/// Computes:
/// - `temp_key = HMAC(chaining_key, input_key_material)`
/// - `output1 = HMAC(temp_key, 0x01)`
/// - `output2 = HMAC(temp_key, output1 || 0x02)`
fn hkdf(chaining_key: &[u8; 32], input_key_material: &[u8]) -> HkdfOutput {
    /// HMAC-SHA256 of the concatenation of `data`, keyed with a 32-byte `key`.
    fn hmac_hash<'a>(
        key: &[u8; 32],
        data: impl IntoIterator<Item = &'a [u8]>,
    ) -> zeroize::Zeroizing<[u8; 32]> {
        // The key is shorter than the 64-byte SHA-256 block: zero-pad it and XOR it into both
        // pads.
        let mut inner_pad = [0x36u8; 64];
        let mut outer_pad = [0x5cu8; 64];
        for ((inner, outer), key_byte) in inner_pad
            .iter_mut()
            .zip(outer_pad.iter_mut())
            .zip(key.iter())
        {
            *inner ^= *key_byte;
            *outer ^= *key_byte;
        }

        let mut inner_hasher = <sha2::Sha256 as sha2::Digest>::new();
        sha2::Digest::update(&mut inner_hasher, inner_pad);
        data.into_iter()
            .for_each(|chunk| sha2::Digest::update(&mut inner_hasher, chunk));
        let inner_digest = sha2::Digest::finalize(inner_hasher);

        let mut outer_hasher = <sha2::Sha256 as sha2::Digest>::new();
        sha2::Digest::update(&mut outer_hasher, outer_pad);
        sha2::Digest::update(&mut outer_hasher, inner_digest);

        let mut mac = zeroize::Zeroizing::new([0; 32]);
        sha2::Digest::finalize_into(
            outer_hasher,
            sha2::digest::generic_array::GenericArray::from_mut_slice(&mut *mac),
        );
        mac
    }

    let prk = hmac_hash(chaining_key, [input_key_material]);
    let output1 = hmac_hash(&prk, [&[0x01u8][..]]);
    let output2 = hmac_hash(&prk, [&output1[..], &[0x02u8][..]]);
    HkdfOutput { output1, output2 }
}
/// The two 32-byte outputs of [`hkdf`].
struct HkdfOutput {
    /// `HMAC(temp_key, 0x01)`.
    output1: zeroize::Zeroizing<[u8; 32]>,
    /// `HMAC(temp_key, output1 || 0x02)`.
    output2: zeroize::Zeroizing<[u8; 32]>,
}
#[cfg(test)]
mod tests {
    use core::{cmp, mem};

    use super::{Config, NoiseHandshake, NoiseKey, ReadWrite};

    /// Gives one side of the handshake a single `read_write` call, unless it already succeeded.
    ///
    /// - `inbound`: bytes sent by the remote and not yet read by this side.
    /// - `outbound`: bytes sent by this side and not yet read by the remote; never exceeds
    ///   `local_capacity`.
    /// - `remote_capacity`: raised to the number of bytes this side says it expects, so that
    ///   the remote gets enough room to send them.
    fn step(
        handshake: NoiseHandshake,
        inbound: &mut Vec<u8>,
        outbound: &mut Vec<u8>,
        local_capacity: usize,
        remote_capacity: &mut usize,
    ) -> NoiseHandshake {
        let in_progress = match handshake {
            NoiseHandshake::InProgress(in_progress) => in_progress,
            finished @ NoiseHandshake::Success { .. } => return finished,
        };

        let mut read_write = ReadWrite {
            now: 0,
            incoming_buffer: mem::take(inbound),
            expected_incoming_bytes: Some(0),
            read_bytes: 0,
            write_bytes_queued: outbound.len(),
            write_bytes_queueable: Some(local_capacity - outbound.len()),
            write_buffers: vec![mem::take(outbound)],
            wake_up_after: None,
        };
        let next = in_progress.read_write(&mut read_write).unwrap();

        *inbound = read_write.incoming_buffer;
        outbound.extend(read_write.write_buffers.drain(..).flatten());
        *remote_capacity = cmp::max(
            *remote_capacity,
            read_write.expected_incoming_bytes.unwrap_or(0),
        );
        next
    }

    #[test]
    fn handshake_basic_works() {
        fn test_with_buffer_sizes(mut size1: usize, mut size2: usize) {
            let key1 = NoiseKey::new(&rand::random(), &rand::random());
            let key2 = NoiseKey::new(&rand::random(), &rand::random());

            let mut handshake1 = NoiseHandshake::new(Config {
                key: &key1,
                is_initiator: true,
                prologue: &[],
                ephemeral_secret_key: &rand::random(),
            });
            let mut handshake2 = NoiseHandshake::new(Config {
                key: &key2,
                is_initiator: false,
                prologue: &[],
                ephemeral_secret_key: &rand::random(),
            });

            let mut buf_1_to_2 = Vec::new();
            let mut buf_2_to_1 = Vec::new();

            loop {
                let both_done = matches!(
                    (&handshake1, &handshake2),
                    (
                        NoiseHandshake::Success { .. },
                        NoiseHandshake::Success { .. }
                    )
                );
                if both_done {
                    break;
                }

                handshake1 = step(
                    handshake1,
                    &mut buf_2_to_1,
                    &mut buf_1_to_2,
                    size1,
                    &mut size2,
                );
                handshake2 = step(
                    handshake2,
                    &mut buf_1_to_2,
                    &mut buf_2_to_1,
                    size2,
                    &mut size1,
                );
            }
        }

        for (size1, size2) in [(256, 256), (1, 1), (1, 2048), (2048, 1)] {
            test_with_buffer_sizes(size1, size2);
        }
    }
}