use crate::{
common::drbg::{self, Drbg, Seed},
constants::MESSAGE_OVERHEAD,
framing::{FrameError, Messages},
};
use bytes::{Buf, BufMut, BytesMut};
use crypto_secretbox::{
aead::{generic_array::GenericArray, Aead, KeyInit},
XSalsa20Poly1305,
};
use ptrs::{debug, error, trace};
use rand::prelude::*;
use tokio_util::codec::{Decoder, Encoder};
// ---- primitive sizes -------------------------------------------------------

/// Size of the big-endian `u16` length prefix written before every frame.
pub(crate) const LENGTH_LENGTH: usize = 2;
/// Size of the Poly1305 authenticator appended to each sealed frame.
pub(crate) const TAG_SIZE: usize = 16;
/// Size of the XSalsa20-Poly1305 secret key.
pub(crate) const KEY_LENGTH: usize = 32;
/// Fixed per-direction part of every nonce.
pub(crate) const NONCE_PREFIX_LENGTH: usize = 16;
/// Big-endian frame counter appended after the nonce prefix.
pub(crate) const NONCE_COUNTER_LENGTH: usize = 8;
/// Full nonce size: prefix followed by counter.
pub(crate) const NONCE_LENGTH: usize = NONCE_PREFIX_LENGTH + NONCE_COUNTER_LENGTH;
/// Key material per direction, laid out as `key || nonce_prefix || drbg_seed`.
pub(crate) const KEY_MATERIAL_LENGTH: usize = KEY_LENGTH + NONCE_PREFIX_LENGTH + drbg::SEED_LENGTH;

// ---- frame geometry --------------------------------------------------------

/// Largest segment including framing overhead: 1500 minus (40 + 12) bytes.
/// NOTE(review): presumably a 1500-byte MTU less IP/TCP header room — confirm.
pub(crate) const MAX_SEGMENT_LENGTH: usize = 1500 - (40 + 12);
/// Bytes the secretbox adds on top of the plaintext (just the tag).
const SECRET_BOX_OVERHEAD: usize = TAG_SIZE;
/// Total per-frame overhead: length prefix plus secretbox tag.
pub(crate) const FRAME_OVERHEAD: usize = LENGTH_LENGTH + SECRET_BOX_OVERHEAD;
/// Largest plaintext payload that fits in one segment.
pub(crate) const MAX_FRAME_PAYLOAD_LENGTH: usize = MAX_SEGMENT_LENGTH - FRAME_OVERHEAD;
/// Largest value the (unmasked) length prefix may carry: the sealed body size.
pub(crate) const MAX_FRAME_LENGTH: usize = MAX_SEGMENT_LENGTH - LENGTH_LENGTH;
/// Smallest sealed body: an empty payload plus its tag.
pub(crate) const MIN_FRAME_LENGTH: usize = FRAME_OVERHEAD - LENGTH_LENGTH;
/// Length-prefixed XSalsa20-Poly1305 framing codec.
///
/// Each direction has independent state: `encoder` seals outgoing frames and
/// `decoder` opens incoming ones, each with its own key, nonce sequence and
/// length-mask DRBG.
pub struct EncryptingCodec {
// Send-side key, nonce counter and length-mask DRBG.
encoder: EncryptingEncoder,
// Receive-side key, nonce counter, length-mask DRBG and partial-frame state.
decoder: EncryptingDecoder,
// `false` on construction; set to `true` by `handshake_complete()`. Not read by
// the codec code shown here — NOTE(review): confirm how callers consume it.
pub(crate) handshake_complete: bool,
}
impl EncryptingCodec {
    /// Builds a codec from two independent key-material blobs.
    ///
    /// `encoder_key_material` keys the frames this side sends.
    /// `decoder_key_material` keys the frames it receives.
    /// The handshake starts out marked as not complete.
    pub fn new(
        encoder_key_material: [u8; KEY_MATERIAL_LENGTH],
        decoder_key_material: [u8; KEY_MATERIAL_LENGTH],
    ) -> Self {
        // Same construction order as before: send side first, then receive side.
        let encoder = EncryptingEncoder::new(encoder_key_material);
        let decoder = EncryptingDecoder::new(decoder_key_material);
        Self {
            encoder,
            decoder,
            handshake_complete: false,
        }
    }

    /// Records that the handshake has finished.
    pub(crate) fn handshake_complete(&mut self) {
        self.handshake_complete = true;
    }
}
/// Receive-side state for [`EncryptingCodec`].
///
/// Holds the key, the nonce sequence and the length-mask DRBG. It also tracks a
/// frame whose length prefix has been read but whose body has not yet been decoded.
struct EncryptingDecoder {
// Secretbox key used to open incoming frames.
key: [u8; KEY_LENGTH],
// Produces the per-frame nonce; advanced once per length prefix consumed.
nonce: NonceBox,
// Supplies the mask XORed onto each received length prefix.
drbg: Drbg,
// Nonce reserved for the pending frame (taken when its length prefix was read).
next_nonce: [u8; NONCE_LENGTH],
// Unmasked body length of the pending frame; 0 means no length prefix is pending.
next_length: u16,
// Set once a received length was out of range and replaced by a random valid one.
// While set, `decode` stops reserving buffer space for the frame body.
next_length_invalid: bool,
}
impl EncryptingDecoder {
    /// Builds receive-side state from `key_material`.
    ///
    /// `key_material` is laid out as `key || nonce_prefix || drbg_seed`.
    ///
    /// # Panics
    /// Panics if `Seed::try_from` rejects the trailing seed bytes or `Drbg::new` fails.
    /// Neither is expected for correctly sized input — NOTE(review): confirm with `drbg`.
    fn new(key_material: [u8; KEY_MATERIAL_LENGTH]) -> Self {
        // The raw key material is a session secret. Only dump it in debug builds so
        // release logs (even at trace level) never contain it.
        if cfg!(debug_assertions) {
            trace!("new decoder key_material: {}", hex::encode(key_material));
        }

        let (key_bytes, rest) = key_material.split_at(KEY_LENGTH);
        let (prefix_bytes, seed_bytes) = rest.split_at(NONCE_PREFIX_LENGTH);

        let key: [u8; KEY_LENGTH] = key_bytes
            .try_into()
            .expect("split_at(KEY_LENGTH) always yields KEY_LENGTH bytes");
        let nonce = NonceBox::new(prefix_bytes);
        let seed = Seed::try_from(seed_bytes)
            .expect("KEY_MATERIAL_LENGTH leaves exactly drbg::SEED_LENGTH seed bytes");
        let drbg = Drbg::new(Some(seed)).expect("failed to build length-mask DRBG from seed");

        Self {
            key,
            drbg,
            nonce,
            next_nonce: [0_u8; NONCE_LENGTH],
            next_length: 0,
            next_length_invalid: false,
        }
    }
}
impl Decoder for EncryptingCodec {
    type Item = Messages;
    type Error = FrameError;

    /// Decodes the next non-padding message from `src`.
    ///
    /// Wire format per frame: a `u16` big-endian length XOR a DRBG mask, followed by
    /// that many bytes of XSalsa20-Poly1305 ciphertext.
    ///
    /// - Frames whose message is `Padding`, or whose type is unknown, are consumed and
    ///   skipped. Decoding then continues with any frames already buffered in `src`.
    ///   Returning `Ok(None)` here would tell `FramedRead` to wait for more bytes,
    ///   stalling frames that have already arrived.
    /// - Returns `Ok(None)` only when `src` does not hold a complete frame.
    ///
    /// # Errors
    /// - Nonce counter exhaustion.
    /// - Authentication/decryption failure.
    /// - Plaintext shorter than `MESSAGE_OVERHEAD`.
    /// - Any `Messages::try_parse` error other than `UnknownMessageType`.
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> std::result::Result<Option<Self::Item>, Self::Error> {
        loop {
            trace!(
                "decoding src:{}B {} {}",
                src.remaining(),
                self.decoder.next_length,
                self.decoder.next_length_invalid
            );

            // Stage 1: read and unmask the length prefix unless one is already pending.
            if self.decoder.next_length == 0 {
                if LENGTH_LENGTH > src.remaining() {
                    return Ok(None);
                }
                // The nonce is reserved together with the length so both stay in
                // lock-step with the sender.
                self.decoder.next_nonce = self.decoder.nonce.next()?;
                let mut length = src.get_u16();
                let length_mask = self.decoder.drbg.length_mask();
                trace!(
                    "decoding {length:04x}^{length_mask:04x} {:04x}B",
                    length ^ length_mask
                );
                length ^= length_mask;
                if MAX_FRAME_LENGTH < length as usize || MIN_FRAME_LENGTH > length as usize {
                    // Pretend the length was a random in-range value (in
                    // [MIN_FRAME_LENGTH, MAX_FRAME_LENGTH)). The stream then fails
                    // later at authentication rather than immediately.
                    // NOTE(review): presumably mirrors obfs4's length-oracle mitigation.
                    let invalid_length = length;
                    self.decoder.next_length_invalid = true;
                    length = rand::thread_rng().gen::<u16>()
                        % (MAX_FRAME_LENGTH - MIN_FRAME_LENGTH) as u16
                        + MIN_FRAME_LENGTH as u16;
                    error!(
                        "invalid length {invalid_length} {length} {}",
                        self.decoder.next_length_invalid
                    );
                }
                self.decoder.next_length = length;
            }

            // Stage 2: wait until the whole sealed body is buffered.
            let next_len = self.decoder.next_length as usize;
            if next_len > src.len() {
                if !self.decoder.next_length_invalid {
                    src.reserve(next_len - src.len());
                }
                trace!(
                    "next_len > src.len --> reading more {} {}",
                    self.decoder.next_length,
                    self.decoder.next_length_invalid
                );
                return Ok(None);
            }

            // Stage 3: authenticate and decrypt directly from the buffer (no copy).
            let key = GenericArray::from_slice(&self.decoder.key);
            let cipher = XSalsa20Poly1305::new(key);
            let nonce = GenericArray::from_slice(&self.decoder.next_nonce);
            let plaintext = match cipher.decrypt(nonce, &src[..next_len]) {
                Ok(plaintext) => plaintext,
                Err(e) => {
                    trace!("failed to decrypt result: {e}");
                    return Err(e.into());
                }
            };
            if plaintext.len() < MESSAGE_OVERHEAD {
                return Err(FrameError::InvalidMessage);
            }

            // The frame is accepted: consume it and clear the pending-length state.
            self.decoder.next_length = 0;
            src.advance(next_len);
            debug!("decoding {next_len}B src:{}B", src.remaining());

            // Stage 4: parse. Skip padding and unknown types, then try the next frame.
            match Messages::try_parse(&mut BytesMut::from(plaintext.as_slice())) {
                Ok(Messages::Padding(_)) | Err(FrameError::UnknownMessageType(_)) => continue,
                Ok(m) => return Ok(Some(m)),
                Err(e) => return Err(e),
            }
        }
    }
}
/// Send-side state for [`EncryptingCodec`]: key, nonce sequence and length-mask DRBG.
struct EncryptingEncoder {
// Secretbox key used to seal outgoing frames.
key: [u8; KEY_LENGTH],
// Produces a fresh nonce for every frame sent.
nonce: NonceBox,
// Supplies the mask XORed onto each outgoing length prefix.
drbg: Drbg,
}
impl EncryptingEncoder {
    /// Builds send-side state from `key_material`.
    ///
    /// `key_material` is laid out as `key || nonce_prefix || drbg_seed`.
    ///
    /// # Panics
    /// Panics if `Seed::try_from` rejects the trailing seed bytes or `Drbg::new` fails.
    /// Neither is expected for correctly sized input — NOTE(review): confirm with `drbg`.
    fn new(key_material: [u8; KEY_MATERIAL_LENGTH]) -> Self {
        // The raw key material is a session secret. Only dump it in debug builds so
        // release logs (even at trace level) never contain it.
        if cfg!(debug_assertions) {
            trace!("new encoder key_material: {}", hex::encode(key_material));
        }

        let (key_bytes, rest) = key_material.split_at(KEY_LENGTH);
        let (prefix_bytes, seed_bytes) = rest.split_at(NONCE_PREFIX_LENGTH);

        let key: [u8; KEY_LENGTH] = key_bytes
            .try_into()
            .expect("split_at(KEY_LENGTH) always yields KEY_LENGTH bytes");
        let nonce = NonceBox::new(prefix_bytes);
        let seed = Seed::try_from(seed_bytes)
            .expect("KEY_MATERIAL_LENGTH leaves exactly drbg::SEED_LENGTH seed bytes");
        let drbg = Drbg::new(Some(seed)).expect("failed to build length-mask DRBG from seed");

        Self { key, nonce, drbg }
    }
}
impl<T: Buf> Encoder<T> for EncryptingCodec {
type Error = FrameError;
fn encode(&mut self, plaintext: T, dst: &mut BytesMut) -> std::result::Result<(), Self::Error> {
trace!(
"encoding {}/{MAX_FRAME_PAYLOAD_LENGTH}",
plaintext.remaining()
);
if plaintext.remaining() > MAX_FRAME_PAYLOAD_LENGTH {
return Err(FrameError::InvalidPayloadLength(plaintext.remaining()));
}
let mut plaintext_frame = BytesMut::new();
plaintext_frame.put(plaintext);
let nonce_bytes = self.encoder.nonce.next()?;
let key = GenericArray::from_slice(&self.encoder.key);
let cipher = XSalsa20Poly1305::new(key);
let nonce = GenericArray::from_slice(&nonce_bytes);
let ciphertext = cipher.encrypt(nonce, plaintext_frame.as_ref())?;
let mut length = ciphertext.len() as u16;
let length_mask: u16 = self.encoder.drbg.length_mask();
debug!(
"encoding➡️ {length}B, {length:04x}^{length_mask:04x} {:04x}",
length ^ length_mask
);
length ^= length_mask;
trace!(
"prng_ciphertext: {}{}",
hex::encode(length.to_be_bytes()),
hex::encode(&ciphertext)
);
dst.extend_from_slice(&length.to_be_bytes()[..]);
dst.extend_from_slice(&ciphertext);
Ok(())
}
}
/// Sequential nonce generator.
///
/// Each nonce is `prefix || counter`, with the counter encoded big-endian and
/// starting at 1.
pub(crate) struct NonceBox {
// Fixed prefix: the first NONCE_PREFIX_LENGTH bytes passed to `new`.
prefix: [u8; NONCE_PREFIX_LENGTH],
// Counter for the next nonce; advanced by one on every successful `next()`.
counter: u64,
}
impl NonceBox {
    /// Creates a generator whose nonces start with the first `NONCE_PREFIX_LENGTH`
    /// bytes of `prefix`; extra bytes are ignored. The counter starts at 1.
    ///
    /// # Panics
    /// Panics if `prefix` is shorter than `NONCE_PREFIX_LENGTH` bytes.
    pub fn new(prefix: impl AsRef<[u8]>) -> Self {
        let bytes = prefix.as_ref();
        assert!(
            bytes.len() >= NONCE_PREFIX_LENGTH,
            "prefix too short: {} < {NONCE_PREFIX_LENGTH}",
            bytes.len()
        );
        let mut fixed = [0_u8; NONCE_PREFIX_LENGTH];
        fixed.copy_from_slice(&bytes[..NONCE_PREFIX_LENGTH]);
        Self {
            prefix: fixed,
            counter: 1,
        }
    }

    /// Returns the next nonce (`prefix || counter_be`) and advances the counter.
    ///
    /// # Errors
    /// `FrameError::NonceCounterWrapped` once the counter reaches `u64::MAX`. That
    /// value is refused rather than emitted, so the counter can never wrap around and
    /// repeat a nonce under the same key.
    pub fn next(&mut self) -> std::result::Result<[u8; NONCE_LENGTH], FrameError> {
        if self.counter == u64::MAX {
            return Err(FrameError::NonceCounterWrapped);
        }

        let mut nonce_l = [0_u8; NONCE_LENGTH];
        let (head, tail) = nonce_l.split_at_mut(NONCE_PREFIX_LENGTH);
        head.copy_from_slice(&self.prefix);
        tail.copy_from_slice(&self.counter.to_be_bytes());

        trace!("fresh nonce: {}", hex::encode(nonce_l));
        self.inc();
        Ok(nonce_l)
    }

    /// Advances the counter by one. `next` guarantees it is below `u64::MAX` here.
    fn inc(&mut self) {
        self.counter += 1;
    }
}
#[cfg(test)]
mod testing {
    use super::*;
    use crate::Result;

    /// Builds a sender/receiver pair whose key material lines up.
    ///
    /// `tx`'s encoder and `rx`'s decoder share the same key, nonce prefix and DRBG
    /// seed, so frames produced by `tx` are openable by `rx`.
    fn codec_pair() -> (EncryptingCodec, EncryptingCodec) {
        let km_a = [0x11_u8; KEY_MATERIAL_LENGTH];
        let km_b = [0x22_u8; KEY_MATERIAL_LENGTH];
        (
            EncryptingCodec::new(km_a, km_b),
            EncryptingCodec::new(km_b, km_a),
        )
    }

    #[test]
    fn nonce_wrap() -> Result<()> {
        let mut nb = NonceBox::new([0_u8; NONCE_PREFIX_LENGTH]);
        nb.counter = u64::MAX;
        assert_eq!(nb.next().unwrap_err(), FrameError::NonceCounterWrapped);
        Ok(())
    }

    #[test]
    fn nonce_last_counter_before_wrap_is_usable() {
        let mut nb = NonceBox::new([0_u8; NONCE_PREFIX_LENGTH]);
        nb.counter = u64::MAX - 1;
        let last = nb.next().unwrap();
        assert_eq!(&last[NONCE_PREFIX_LENGTH..], &(u64::MAX - 1).to_be_bytes()[..]);
        assert_eq!(nb.next().unwrap_err(), FrameError::NonceCounterWrapped);
    }

    #[test]
    fn nonce_layout_is_prefix_then_big_endian_counter() {
        let prefix = [0xa5_u8; NONCE_PREFIX_LENGTH];
        let mut nb = NonceBox::new(prefix);
        let first = nb.next().unwrap();
        assert_eq!(&first[..NONCE_PREFIX_LENGTH], &prefix[..]);
        assert_eq!(&first[NONCE_PREFIX_LENGTH..], &1_u64.to_be_bytes()[..]);
        let second = nb.next().unwrap();
        assert_eq!(&second[NONCE_PREFIX_LENGTH..], &2_u64.to_be_bytes()[..]);
    }

    #[test]
    fn nonce_prefix_extra_bytes_are_ignored() {
        let mut long = [0_u8; NONCE_PREFIX_LENGTH + 4];
        long[NONCE_PREFIX_LENGTH..].fill(0xff);
        let mut nb = NonceBox::new(long);
        let nonce = nb.next().unwrap();
        assert_eq!(&nonce[..NONCE_PREFIX_LENGTH], &[0_u8; NONCE_PREFIX_LENGTH][..]);
    }

    #[test]
    #[should_panic(expected = "prefix too short")]
    fn nonce_prefix_too_short_panics() {
        let _ = NonceBox::new([0_u8; NONCE_PREFIX_LENGTH - 1]);
    }

    #[test]
    fn encode_rejects_oversized_payload() {
        let (mut tx, _) = codec_pair();
        let payload = vec![0_u8; MAX_FRAME_PAYLOAD_LENGTH + 1];
        let mut dst = BytesMut::new();
        let err = tx.encode(&payload[..], &mut dst).unwrap_err();
        assert_eq!(err, FrameError::InvalidPayloadLength(MAX_FRAME_PAYLOAD_LENGTH + 1));
        assert!(dst.is_empty());
    }

    #[test]
    fn encoded_frame_is_prefix_payload_and_tag() {
        let (mut tx, _) = codec_pair();
        let payload = [7_u8; 10];
        let mut dst = BytesMut::new();
        tx.encode(&payload[..], &mut dst).unwrap();
        assert_eq!(dst.len(), LENGTH_LENGTH + payload.len() + TAG_SIZE);
    }

    #[test]
    fn decode_waits_for_partial_frame() {
        let (mut tx, mut rx) = codec_pair();
        let mut wire = BytesMut::new();
        tx.encode(&[0_u8; 10][..], &mut wire).unwrap();
        let mut partial = BytesMut::from(&wire[..wire.len() - 1]);
        assert!(matches!(rx.decode(&mut partial), Ok(None)));
    }

    #[test]
    fn decode_rejects_tampered_frame() {
        let (mut tx, mut rx) = codec_pair();
        let mut wire = BytesMut::new();
        tx.encode(&[0_u8; 10][..], &mut wire).unwrap();
        let last = wire.len() - 1;
        wire[last] ^= 0x01;
        assert!(rx.decode(&mut wire).is_err());
    }
}