use byteorder::{ByteOrder, LittleEndian};
use bytes::BytesMut;
use snow::{Builder, HandshakeState, TransportState};
use std::fmt::{self, Error, Formatter};
use super::{handshake::HandshakeParams, resolver::SodiumResolver};
use crate::events::noise::{error::NoiseError, HEADER_LENGTH, MAX_MESSAGE_LENGTH, TAG_LENGTH};
/// NOTE(review): not referenced in this part of the file; presumably the size in bytes of
/// the length prefix placed in front of each handshake message — confirm against the codec.
pub const HANDSHAKE_HEADER_LENGTH: usize = 2;
/// NOTE(review): not referenced in this part of the file; presumably the upper bound on the
/// size of a single handshake message — confirm against the code that uses it.
pub const MAX_HANDSHAKE_MESSAGE_LENGTH: usize = 65535;
/// Smallest input accepted by `NoiseWrapper::read_handshake_msg`; shorter inputs are
/// rejected with `NoiseError::WrongMessageLength`.
pub const MIN_HANDSHAKE_MESSAGE_LENGTH: usize = 32;
/// Noise protocol name that `NoiseWrapper::noise_builder` parses to configure the builder.
static PARAMS: &str = "Noise_XK_25519_ChaChaPoly_SHA256";
/// Wrapper around a `snow` [`HandshakeState`], used while the Noise handshake is in progress.
///
/// Once the handshake is complete it can be converted into a [`TransportWrapper`] with
/// `into_transport_wrapper`.
pub struct NoiseWrapper {
/// Underlying `snow` handshake state.
pub state: HandshakeState,
}
impl NoiseWrapper {
pub fn initiator(params: &HandshakeParams) -> Self {
if let Some(ref remote_key) = params.remote_key {
let builder: Builder<'_> = Self::noise_builder()
.local_private_key(params.secret_key.as_ref())
.remote_public_key(remote_key.as_ref());
let state = builder
.build_initiator()
.expect("Noise session initiator failed to initialize");
Self { state }
} else {
panic!("Remote public key is not specified")
}
}
pub fn responder(params: &HandshakeParams) -> Self {
let builder: Builder<'_> = Self::noise_builder();
let state = builder
.local_private_key(params.secret_key.as_ref())
.build_responder()
.expect("Noise session responder failed to initialize");
Self { state }
}
pub fn read_handshake_msg(&mut self, input: &[u8]) -> Result<Vec<u8>, NoiseError> {
if input.len() < MIN_HANDSHAKE_MESSAGE_LENGTH || input.len() > MAX_MESSAGE_LENGTH {
return Err(NoiseError::WrongMessageLength(input.len()));
}
let mut buf = vec![0_u8; MAX_MESSAGE_LENGTH];
let len = self.state.read_message(input, &mut buf)?;
buf.truncate(len);
Ok(buf)
}
pub fn write_handshake_msg(&mut self, msg: &[u8]) -> Result<Vec<u8>, NoiseError> {
let mut buf = vec![0_u8; MAX_MESSAGE_LENGTH];
let len = self.state.write_message(msg, &mut buf)?;
buf.truncate(len);
Ok(buf)
}
pub fn into_transport_wrapper(self) -> Result<TransportWrapper, NoiseError> {
let state = self.state.into_transport_mode()?;
Ok(TransportWrapper { state })
}
fn noise_builder<'a>() -> Builder<'a> {
Builder::with_resolver(PARAMS.parse().unwrap(), Box::new(SodiumResolver::new()))
}
}
impl fmt::Debug for NoiseWrapper {
    /// Prints only whether the handshake has finished; the state itself is not dumped.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        let finished = self.state.is_handshake_finished();
        f.write_fmt(format_args!(
            "NoiseWrapper {{ handshake finished: {} }}",
            finished
        ))
    }
}
/// Wrapper around a `snow` [`TransportState`], used to encrypt and decrypt messages
/// after the handshake has completed.
pub struct TransportWrapper {
/// Underlying `snow` transport state.
pub state: TransportState,
}
impl TransportWrapper {
pub fn decrypt_msg(&mut self, len: usize, buf: &mut BytesMut) -> anyhow::Result<BytesMut> {
debug_assert!(len + HEADER_LENGTH <= buf.len());
let data = buf.split_to(len + HEADER_LENGTH).to_vec();
let data = &data[HEADER_LENGTH..];
let len = decrypted_msg_len(data.len());
let mut decrypted_message = BytesMut::with_capacity(len);
decrypted_message.resize(len, 0);
let mut read = vec![0_u8; MAX_MESSAGE_LENGTH];
for (i, msg) in data.chunks(MAX_MESSAGE_LENGTH).enumerate() {
let len = self.state.read_message(msg, &mut read)?;
let start = i * (MAX_MESSAGE_LENGTH - TAG_LENGTH);
let end = start + len;
decrypted_message[start..end].copy_from_slice(&read[..len]);
}
Ok(decrypted_message)
}
pub fn encrypt_msg(&mut self, msg: &[u8], buf: &mut BytesMut) -> anyhow::Result<()> {
const CHUNK_LENGTH: usize = MAX_MESSAGE_LENGTH - TAG_LENGTH;
let len = encrypted_msg_len(msg.len());
let mut encrypted_message = vec![0; len + HEADER_LENGTH];
LittleEndian::write_u32(&mut encrypted_message[..HEADER_LENGTH], len as u32);
let mut written = vec![0_u8; MAX_MESSAGE_LENGTH];
for (i, msg) in msg.chunks(CHUNK_LENGTH).enumerate() {
let len = self.state.write_message(msg, &mut written)?;
let start = HEADER_LENGTH + i * MAX_MESSAGE_LENGTH;
let end = start + len;
encrypted_message[start..end].copy_from_slice(&written[..len]);
}
buf.extend_from_slice(&encrypted_message);
Ok(())
}
}
/// Returns the plaintext length of an encrypted payload of `raw_message_len` bytes.
///
/// The payload is treated as a sequence of chunks of `MAX_MESSAGE_LENGTH` bytes (the last
/// one may be shorter), each of which loses `TAG_LENGTH` bytes when decrypted.
///
/// A trailing chunk shorter than `TAG_LENGTH` contributes zero bytes rather than making
/// the subtraction underflow; results for well-formed lengths are unaffected.
fn decrypted_msg_len(raw_message_len: usize) -> usize {
    let full_chunks = raw_message_len / MAX_MESSAGE_LENGTH;
    let trailing = raw_message_len % MAX_MESSAGE_LENGTH;
    full_chunks * (MAX_MESSAGE_LENGTH - TAG_LENGTH) + trailing.saturating_sub(TAG_LENGTH)
}
/// Returns the encrypted length of a plaintext of `raw_message_len` bytes.
///
/// The plaintext is split into chunks of at most `MAX_MESSAGE_LENGTH - TAG_LENGTH` bytes,
/// and each chunk grows by `TAG_LENGTH` bytes when encrypted.
fn encrypted_msg_len(raw_message_len: usize) -> usize {
    const PAYLOAD_PER_CHUNK: usize = MAX_MESSAGE_LENGTH - TAG_LENGTH;
    let chunks = div_ceil(raw_message_len, PAYLOAD_PER_CHUNK);
    chunks * TAG_LENGTH + raw_message_len
}
/// Divides `lhs` by `rhs`, rounding the result up to the next integer.
///
/// # Panics
///
/// Panics if `rhs` is zero.
fn div_ceil(lhs: usize, rhs: usize) -> usize {
    let quotient = lhs / rhs;
    if lhs % rhs == 0 {
        quotient
    } else {
        // A non-zero remainder means one more partially filled unit is needed.
        quotient + 1
    }
}
impl fmt::Debug for TransportWrapper {
    /// Prints only whether this side is the initiator; the state itself is not dumped.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        let initiator = self.state.is_initiator();
        f.write_fmt(format_args!(
            "TransportWrapper {{ is initiator: {} }}",
            initiator
        ))
    }
}