use std::io::Cursor;
use anyhow::Result;
use aws_lc_rs::{
aead::{Aad, LessSafeKey, Nonce},
error::Unspecified,
hmac::{Key, verify},
};
use bincode_next::{Decode, Encode, config::standard, decode_from_slice};
use tracing::error;
use uuid::Uuid;
use crate::{
MoshpitError, UuidWrapper,
error::Error,
frames::{get_bytes, get_nonce, get_usize},
};
/// Byte length of the UUID that prefixes every decrypted frame payload (a UUID is 128 bits).
const UUID_LEN: usize = std::mem::size_of::<u128>();
/// Byte length of the authentication tag appended to the ciphertext by the AEAD seal.
const AEAD_TAG_LEN: usize = 128 / 8;
/// Upper bound, in bytes, on the declared ciphertext length accepted when parsing a frame.
pub(crate) const MAX_ENCFRAME_LENGTH: usize = 64 * 1024;
/// Frames exchanged inside the encrypted transport.
///
/// Values are serialized with bincode through the derived `Encode`/`Decode` impls. On the
/// receive side, `EncryptedFrame::parse` decrypts a payload laid out as
/// `UUID || bincode(frame)` and decodes the frame from the bytes after the UUID.
#[derive(Clone, Debug, Decode, Encode, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum EncryptedFrame {
    // NOTE: declaration order is part of the wire format. The derived bincode encoding
    // identifies variants by their index, and `EncryptedFrame::id` mirrors this order.
    // Append new variants at the end; never reorder or insert in the middle.
    Bytes((UuidWrapper, Vec<u8>)),
    Resize((UuidWrapper, u16, u16)),
    Nak(Vec<u64>),
    Shutdown,
    Keepalive(u64),
    ScrollbackStart,
    ScrollbackEnd,
    ScreenState(Vec<u8>),
    RepaintRequest,
    ScreenStateCompressed(Vec<u8>),
    CompressedBytes((UuidWrapper, Vec<u8>)),
    StateSyncDiff((u64, u64, Vec<u8>)),
    ClientAck(u64),
    PtyExit,
    StateChunk((u16, u16, Vec<u8>)),
}
impl EncryptedFrame {
#[must_use]
pub fn id(&self) -> u8 {
match self {
EncryptedFrame::Bytes(_) => 0,
EncryptedFrame::Resize(_) => 1,
EncryptedFrame::Nak(_) => 2,
EncryptedFrame::Shutdown => 3,
EncryptedFrame::Keepalive(_) => 4,
EncryptedFrame::ScrollbackStart => 5,
EncryptedFrame::ScrollbackEnd => 6,
EncryptedFrame::ScreenState(_) => 7,
EncryptedFrame::RepaintRequest => 8,
EncryptedFrame::ScreenStateCompressed(_) => 9,
EncryptedFrame::CompressedBytes(_) => 10,
EncryptedFrame::StateSyncDiff(_) => 11,
EncryptedFrame::ClientAck(_) => 12,
EncryptedFrame::PtyExit => 13,
EncryptedFrame::StateChunk(_) => 14,
}
}
pub fn parse(
src: &mut Cursor<&[u8]>,
id: Uuid,
hmac: &Key,
rnk: &LessSafeKey,
mac_tag_len: usize,
) -> Result<Option<(Self, u64)>> {
let Some(nonce_bytes) = get_nonce(src)? else {
return Ok(None);
};
let Some(seq_bytes) = get_usize(src)? else {
return Ok(None);
};
let seq = u64::from_be_bytes(seq_bytes.try_into()?);
if let Some(tag_bytes) = get_bytes(src, mac_tag_len)?
&& let Some(length_slice) = get_usize(src)?
{
let length = usize::from_be_bytes(length_slice.try_into()?);
if length > MAX_ENCFRAME_LENGTH {
return Err(Error::FrameTooLarge.into());
}
if let Some(data) = get_bytes(src, length)? {
let mut to_verify = seq_bytes.to_vec();
to_verify.extend_from_slice(data);
if let Ok(()) = verify(hmac, &to_verify, tag_bytes) {
let mut data = data.to_vec();
let nonce = Nonce::try_assume_unique_for_key(nonce_bytes)?;
let aad = Aad::from(seq.to_be_bytes());
let _ = rnk.open_in_place(nonce, aad, &mut data)?;
let (uuid_bytes, rest) = data.split_at(UUID_LEN);
let uuid = Uuid::from_bytes(uuid_bytes.try_into()?);
if uuid != id {
error!("UUID mismatch: expected {id}, got {uuid}");
return Err(MoshpitError::UuidMismatch.into());
}
let mut message_with_tag = rest.to_vec();
message_with_tag.reverse();
let mut message = message_with_tag.split_off(AEAD_TAG_LEN);
message.reverse();
let config = standard().with_limit::<65536>();
let frame_data: (EncryptedFrame, _) = decode_from_slice(&message, config)?;
return Ok(Some((frame_data.0, seq)));
}
error!("HMAC verification failed");
return Err(Unspecified.into());
}
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use aws_lc_rs::{
        aead::{AES_256_GCM_SIV, Aad, LessSafeKey, NONCE_LEN, Nonce, UnboundKey},
        hmac::{HMAC_SHA512, Key, sign},
        rand,
    };
    use bincode_next::{config::standard, encode_to_vec};
    use uuid::Uuid;

    use super::{EncryptedFrame, MAX_ENCFRAME_LENGTH};
    use crate::UuidWrapper;

    /// Length in bytes of an HMAC-SHA512 tag; passed as `mac_tag_len` to `parse`.
    const MAC_TAG_LEN: usize = 64;

    /// Creates a fresh UUID plus fixed-value AES-256-GCM-SIV and HMAC-SHA512 keys.
    fn make_keys() -> (Uuid, LessSafeKey, Key) {
        let id = Uuid::new_v4();
        let rnk = LessSafeKey::new(UnboundKey::new(&AES_256_GCM_SIV, &[1u8; 32]).unwrap());
        let hmac = Key::new(HMAC_SHA512, &[2u8; 64]);
        (id, rnk, hmac)
    }

    /// Builds a wire packet `nonce || seq || hmac_tag || len || ciphertext`.
    ///
    /// The ciphertext is `id || payload`, sealed under `rnk` with the big-endian `seq` as AAD.
    /// The HMAC tag covers `seq || ciphertext`. `len` is the real ciphertext length unless
    /// `declared_len` overrides it.
    fn build_packet(
        payload: &[u8],
        seq: u64,
        id: Uuid,
        rnk: &LessSafeKey,
        hmac: &Key,
        declared_len: Option<usize>,
    ) -> Vec<u8> {
        let seq_bytes = seq.to_be_bytes();

        let mut ciphertext = id.as_bytes().to_vec();
        ciphertext.extend_from_slice(payload);
        let mut nonce_bytes = [0u8; NONCE_LEN];
        rand::fill(&mut nonce_bytes).unwrap();
        let nonce = Nonce::try_assume_unique_for_key(&nonce_bytes).unwrap();
        rnk.seal_in_place_append_tag(nonce, Aad::from(seq_bytes), &mut ciphertext)
            .unwrap();

        let mut to_sign = seq_bytes.to_vec();
        to_sign.extend_from_slice(&ciphertext);
        let tag = sign(hmac, &to_sign);
        assert_eq!(tag.as_ref().len(), MAC_TAG_LEN, "unexpected HMAC tag length");

        let len = declared_len.unwrap_or(ciphertext.len());
        let mut packet = nonce_bytes.to_vec();
        packet.extend_from_slice(&seq_bytes);
        packet.extend_from_slice(tag.as_ref());
        packet.extend_from_slice(&len.to_be_bytes());
        packet.extend_from_slice(&ciphertext);
        packet
    }

    /// Serializes `frame` with bincode and wraps it in a correctly framed packet.
    fn encrypt_frame(
        frame: &EncryptedFrame,
        seq: u64,
        id: Uuid,
        rnk: &LessSafeKey,
        hmac: &Key,
    ) -> Vec<u8> {
        let payload = encode_to_vec(frame, standard()).unwrap();
        build_packet(&payload, seq, id, rnk, hmac, None)
    }

    #[test]
    fn frame_id_variants_are_correct() {
        let uuid = Uuid::new_v4();
        assert_eq!(
            EncryptedFrame::Bytes((UuidWrapper::new(uuid), vec![])).id(),
            0
        );
        assert_eq!(
            EncryptedFrame::Resize((UuidWrapper::new(uuid), 0, 0)).id(),
            1
        );
        assert_eq!(EncryptedFrame::Nak(vec![]).id(), 2);
        assert_eq!(EncryptedFrame::Shutdown.id(), 3);
        assert_eq!(EncryptedFrame::Keepalive(0).id(), 4);
        assert_eq!(EncryptedFrame::ScrollbackStart.id(), 5);
        assert_eq!(EncryptedFrame::ScrollbackEnd.id(), 6);
        assert_eq!(EncryptedFrame::ScreenState(vec![]).id(), 7);
        assert_eq!(EncryptedFrame::RepaintRequest.id(), 8);
        assert_eq!(EncryptedFrame::ScreenStateCompressed(vec![]).id(), 9);
        assert_eq!(
            EncryptedFrame::CompressedBytes((UuidWrapper::new(uuid), vec![])).id(),
            10
        );
        assert_eq!(EncryptedFrame::StateSyncDiff((0, 0, vec![])).id(), 11);
        assert_eq!(EncryptedFrame::ClientAck(0).id(), 12);
        assert_eq!(EncryptedFrame::PtyExit.id(), 13);
        assert_eq!(EncryptedFrame::StateChunk((0, 1, vec![])).id(), 14);
    }

    #[test]
    fn parse_round_trip_keepalive() {
        let (id, rnk, hmac) = make_keys();
        let ts = 1_234_567_890_u64;
        let packet = encrypt_frame(&EncryptedFrame::Keepalive(ts), 0, id, &rnk, &hmac);
        let mut cursor = Cursor::new(packet.as_slice());
        let (parsed_frame, seq) =
            EncryptedFrame::parse(&mut cursor, id, &hmac, &rnk, MAC_TAG_LEN)
                .unwrap()
                .unwrap();
        assert_eq!(parsed_frame, EncryptedFrame::Keepalive(ts));
        assert_eq!(seq, 0);
    }

    #[test]
    fn parse_round_trip_all_aead_algorithms_separate_key_instances() {
        use aws_lc_rs::aead::{AES_128_GCM_SIV, AES_256_GCM, CHACHA20_POLY1305};
        let algorithms: &[(&aws_lc_rs::aead::Algorithm, &[u8])] = &[
            (&AES_256_GCM_SIV, &[1u8; 32]),
            (&AES_256_GCM, &[2u8; 32]),
            (&CHACHA20_POLY1305, &[3u8; 32]),
            (&AES_128_GCM_SIV, &[4u8; 16]),
        ];
        for (alg, key_bytes) in algorithms {
            eprintln!("testing alg={alg:?} key_len={}", key_bytes.len());
            let id = Uuid::new_v4();
            let hmac = Key::new(HMAC_SHA512, &[5u8; 64]);
            let enc_key = LessSafeKey::new(
                UnboundKey::new(alg, key_bytes)
                    .unwrap_or_else(|e| panic!("enc_key creation failed for {alg:?}: {e:?}")),
            );
            let dec_key = LessSafeKey::new(
                UnboundKey::new(alg, key_bytes)
                    .unwrap_or_else(|e| panic!("dec_key creation failed for {alg:?}: {e:?}")),
            );
            let ts = 42_u64;
            let packet = encrypt_frame(&EncryptedFrame::Keepalive(ts), 7, id, &enc_key, &hmac);
            let mut cursor = Cursor::new(packet.as_slice());
            let result = EncryptedFrame::parse(&mut cursor, id, &hmac, &dec_key, MAC_TAG_LEN);
            let (parsed_frame, seq) = match result {
                Ok(Some(inner)) => inner,
                Ok(None) => panic!("parse returned None for algorithm {alg:?}"),
                Err(e) => panic!("parse failed for algorithm {alg:?}: {e}"),
            };
            assert_eq!(
                parsed_frame,
                EncryptedFrame::Keepalive(ts),
                "wrong frame for {alg:?}"
            );
            assert_eq!(seq, 7, "wrong seq for {alg:?}");
        }
    }

    #[test]
    fn parse_round_trip_shutdown() {
        let (id, rnk, hmac) = make_keys();
        let packet = encrypt_frame(&EncryptedFrame::Shutdown, 42, id, &rnk, &hmac);
        let mut cursor = Cursor::new(packet.as_slice());
        let (parsed_frame, seq) =
            EncryptedFrame::parse(&mut cursor, id, &hmac, &rnk, MAC_TAG_LEN)
                .unwrap()
                .unwrap();
        assert_eq!(parsed_frame, EncryptedFrame::Shutdown);
        assert_eq!(seq, 42);
    }

    #[test]
    fn parse_truncated_returns_none() {
        let (id, rnk, hmac) = make_keys();
        let packet = [0u8; 4];
        let mut cursor = Cursor::new(packet.as_slice());
        let result = EncryptedFrame::parse(&mut cursor, id, &hmac, &rnk, MAC_TAG_LEN).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn parse_wrong_uuid_returns_error() {
        let (id, rnk, hmac) = make_keys();
        let packet = encrypt_frame(&EncryptedFrame::Keepalive(0), 0, id, &rnk, &hmac);
        let wrong_id = Uuid::new_v4();
        let mut cursor = Cursor::new(packet.as_slice());
        let result = EncryptedFrame::parse(&mut cursor, wrong_id, &hmac, &rnk, MAC_TAG_LEN);
        assert!(result.is_err());
    }

    #[test]
    fn parse_tampered_ciphertext_returns_error() {
        let (id, rnk, hmac) = make_keys();
        let mut packet = encrypt_frame(&EncryptedFrame::Keepalive(0), 0, id, &rnk, &hmac);
        // The ciphertext is the tail of the packet, so flipping its last byte must break the HMAC.
        *packet.last_mut().unwrap() ^= 0x01;
        let mut cursor = Cursor::new(packet.as_slice());
        let result = EncryptedFrame::parse(&mut cursor, id, &hmac, &rnk, MAC_TAG_LEN);
        assert!(result.is_err());
    }

    #[test]
    fn parse_wrong_hmac_key_returns_error() {
        let (id, rnk, hmac) = make_keys();
        let packet = encrypt_frame(&EncryptedFrame::Keepalive(0), 0, id, &rnk, &hmac);
        let other_hmac = Key::new(HMAC_SHA512, &[9u8; 64]);
        let mut cursor = Cursor::new(packet.as_slice());
        let result = EncryptedFrame::parse(&mut cursor, id, &other_hmac, &rnk, MAC_TAG_LEN);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_oversized_encframe() {
        let (id, rnk, hmac) = make_keys();
        let packet = build_packet(
            &[0u8; 10],
            0,
            id,
            &rnk,
            &hmac,
            Some(MAX_ENCFRAME_LENGTH + 1),
        );
        let mut cursor = Cursor::new(packet.as_slice());
        let result = EncryptedFrame::parse(&mut cursor, id, &hmac, &rnk, MAC_TAG_LEN);
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            crate::error::Error::FrameTooLarge.to_string()
        );
    }
}