use crate::constants::HDP_HEADER_BYTE_LEN;
use crate::error::NetworkError;
use crate::proto::misc::dual_cell::DualCell;
use byteorder::WriteBytesExt;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use citadel_io as rand;
use citadel_io::RngCore;
use citadel_types::crypto::HeaderObfuscatorSettings;
use rand::Rng;
use rand::ThreadRng;
use sha3::Digest;
use std::net::SocketAddr;
use std::num::NonZero;
use zerocopy::byteorder::big_endian::{I64, U128, U32, U64};
use zerocopy::BigEndian;
use zerocopy::{AsBytes, FromBytes, FromZeroes, Ref, Unaligned};
/// Command bytes used in the HDP header, plus payload-level identifiers.
///
/// NOTE(review): these values presumably fill the header's `cmd_primary` / `cmd_aux`
/// bytes and are therefore shared with peers, so existing values must never be
/// renumbered. Several ranges have gaps (e.g. `do_connect` has no 2, `do_preconnect`
/// has no 3-5 or 9); they look like retired codes — confirm before reusing them.
pub(crate) mod packet_flags {
    /// Command codes: a primary command plus a command-specific auxiliary code.
    pub(crate) mod cmd {
        /// Values for the primary command byte.
        pub(crate) mod primary {
            pub(crate) const KEEP_ALIVE: u8 = 0;
            pub(crate) const DO_CONNECT: u8 = 1;
            pub(crate) const GROUP_PACKET: u8 = 2;
            pub(crate) const DO_REGISTER: u8 = 3;
            pub(crate) const DO_DISCONNECT: u8 = 4;
            pub(crate) const DO_DEREGISTER: u8 = 5;
            pub(crate) const DO_PRE_CONNECT: u8 = 6;
            pub(crate) const PEER_CMD: u8 = 7;
            pub(crate) const FILE: u8 = 8;
            pub(crate) const UDP: u8 = 9;
            pub(crate) const HOLE_PUNCH: u8 = 10;
        }

        /// Auxiliary codes. Each submodule is named after the primary command it
        /// presumably qualifies (pairing inferred from names only).
        pub(crate) mod aux {
            /// Group-transfer sub-commands.
            pub(crate) mod group {
                pub(crate) const GROUP_HEADER: u8 = 0;
                pub(crate) const GROUP_HEADER_ACK: u8 = 1;
                pub(crate) const GROUP_PAYLOAD: u8 = 2;
                pub(crate) const WAVE_ACK: u8 = 3;
            }

            /// Connect-handshake stages and outcomes.
            pub(crate) mod do_connect {
                pub(crate) const STAGE0: u8 = 0;
                pub(crate) const STAGE1: u8 = 1;
                pub(crate) const SUCCESS: u8 = 3;
                pub(crate) const FAILURE: u8 = 4;
                pub(crate) const SUCCESS_ACK: u8 = 5;
            }

            /// Registration stages and outcomes.
            pub(crate) mod do_register {
                pub(crate) const STAGE0: u8 = 0;
                pub(crate) const STAGE1: u8 = 1;
                pub(crate) const STAGE2: u8 = 2;
                pub(crate) const SUCCESS: u8 = 5;
                pub(crate) const FAILURE: u8 = 6;
            }

            /// Disconnect stages.
            pub(crate) mod do_disconnect {
                pub(crate) const STAGE0: u8 = 0;
                pub(crate) const FINAL: u8 = 1;
            }

            /// Deregistration stages and outcomes.
            pub(crate) mod do_deregister {
                pub(crate) const STAGE0: u8 = 0;
                pub(crate) const SUCCESS: u8 = 3;
                pub(crate) const FAILURE: u8 = 4;
            }

            /// Pre-connect stages and outcomes.
            pub(crate) mod do_preconnect {
                pub(crate) const SYN: u8 = 0;
                pub(crate) const SYN_ACK: u8 = 1;
                pub(crate) const STAGE0: u8 = 2;
                pub(crate) const SUCCESS: u8 = 6;
                pub(crate) const FAILURE: u8 = 7;
                pub(crate) const BEGIN_CONNECT: u8 = 8;
                pub(crate) const HALT: u8 = 10;
            }

            /// Peer-command sub-types.
            pub(crate) mod peer_cmd {
                pub(crate) const SIGNAL: u8 = 0;
                pub(crate) const CHANNEL: u8 = 1;
                pub(crate) const GROUP_BROADCAST: u8 = 2;
            }

            /// File-transfer and RE-VFS sub-commands.
            pub(crate) mod file {
                pub(crate) const FILE_HEADER: u8 = 0;
                pub(crate) const FILE_HEADER_ACK: u8 = 1;
                pub(crate) const REVFS_PULL: u8 = 2;
                pub(crate) const REVFS_DELETE: u8 = 3;
                pub(crate) const REVFS_ACK: u8 = 4;
                pub(crate) const REVFS_PULL_ACK: u8 = 5;
                pub(crate) const FILE_ERROR: u8 = 6;
            }

            /// UDP sub-commands.
            pub(crate) mod udp {
                pub(crate) const STREAM: u8 = 0;
                pub(crate) const KEEP_ALIVE: u8 = 1;
                pub(crate) const HOLE_PUNCH: u8 = 2;
            }
        }
    }

    /// Identifiers carried in payloads rather than in the header command bytes.
    pub(crate) mod payload_identifiers {
        /// Pre-connect payload identifiers.
        pub(crate) mod do_preconnect {
            pub(crate) const TCP_ONLY: u8 = 1;
        }
    }
}
/// Fixed byte lengths of particular packet kinds, all measured from the start of the
/// HDP header.
pub(crate) mod packet_sizes {
    use crate::constants::HDP_HEADER_BYTE_LEN as HEADER_LEN;

    /// Base length of a group-header packet: the HDP header plus one byte.
    pub(crate) const GROUP_HEADER_BASE_LEN: usize = HEADER_LEN + 1;

    /// Length of a group-header ACK: the HDP header followed by two 1-byte fields and
    /// two 4-byte fields.
    pub(crate) const GROUP_HEADER_ACK_LEN: usize = HEADER_LEN + 1 + 1 + 4 + 4;
}
/// Fixed-size HDP packet header, (de)serialized by reinterpreting its bytes in place.
///
/// `#[repr(C)]` together with the `AsBytes`/`Unaligned` derives means the struct has no
/// padding and alignment 1, so its in-memory bytes are exactly the bytes written by
/// `inscribe_into`: 1+1+1+1+4+16+8+4+8+4+8+8 = 64 bytes. Every multi-byte field uses
/// zerocopy's big-endian wrapper types. Field order therefore *is* the serialized
/// layout — do not reorder, insert or resize fields.
///
/// NOTE(review): presumably `HDP_HEADER_BYTE_LEN` equals `size_of::<HdpHeader>()`;
/// confirm in `crate::constants`.
#[derive(Debug, FromZeroes, AsBytes, FromBytes, Unaligned, Clone)]
#[repr(C)]
pub struct HdpHeader {
    /// Primary command byte (expected to be one of `packet_flags::cmd::primary`).
    pub cmd_primary: u8,
    /// Auxiliary command byte, interpreted relative to `cmd_primary`
    /// (see `packet_flags::cmd::aux`).
    pub cmd_aux: u8,
    /// Presumably identifies the crypto algorithm in use — confirm.
    pub algorithm: u8,
    /// Presumably the negotiated security level — confirm.
    pub security_level: u8,
    /// Protocol version (big-endian u32).
    pub protocol_version: U32,
    /// Opaque 128-bit context value (big-endian); meaning depends on the command.
    pub context_info: U128,
    /// Group identifier (big-endian u64); presumably used by group transfers.
    pub group: U64,
    /// Wave identifier within a group (big-endian u32).
    pub wave_id: U32,
    /// Session connection id (big-endian u64).
    pub session_cid: U64,
    /// Entropy bank version (big-endian u32); presumably selects the key material used.
    pub entropy_bank_version: U32,
    /// Signed 64-bit timestamp (big-endian); units not visible here — confirm.
    pub timestamp: I64,
    /// Target connection id (big-endian u64); presumably the peer addressed by the packet.
    pub target_cid: U64,
}
impl AsRef<[u8]> for HdpHeader {
fn as_ref(&self) -> &[u8] {
self.as_bytes()
}
}
impl HdpHeader {
pub fn inscribe_into<B: BufMut>(&self, mut writer: B) {
writer.put_slice(self.as_bytes())
}
pub fn as_packet(&self) -> BytesMut {
BytesMut::from(self.as_bytes())
}
}
/// Result of [`HdpPacket::parse`]: a zero-copy view of the leading [`HdpHeader`] and the
/// bytes that follow it, both borrowed from the packet buffer.
pub type ParsedPacket<'a> = (Ref<&'a [u8], HdpHeader>, &'a [u8]);

/// A received packet together with the addressing metadata supplied to `new_recv`.
///
/// `B` is the owning buffer (defaults to `BytesMut`); its bytes are expected to start
/// with an [`HdpHeader`] followed by the payload.
pub struct HdpPacket<B: HdpBuffer = BytesMut> {
    /// Raw bytes: header first, then payload.
    packet: B,
    /// Remote peer address recorded at construction.
    remote_peer: SocketAddr,
    /// Local port recorded at construction.
    local_port: u16,
}
impl<B: HdpBuffer> HdpPacket<B> {
pub fn new_recv(packet: B, remote_peer: SocketAddr, local_port: u16) -> Self {
Self {
packet,
remote_peer,
local_port,
}
}
pub(crate) fn as_bytes(&self) -> &[u8] {
self.packet.as_ref()
}
pub fn parse(&self) -> Option<ParsedPacket> {
Ref::new_from_prefix(self.packet.as_ref())
}
pub fn into_packet(self) -> B {
self.packet
}
pub fn get_length(&self) -> usize {
self.packet.len()
}
pub fn decompose(mut self) -> (B::Immutable, B, SocketAddr, u16) {
let header_bytes = self.packet.split_to(HDP_HEADER_BYTE_LEN).to_immutable();
let payload_bytes = self.packet;
let remote_peer = self.remote_peer;
let local_port = self.local_port;
(header_bytes, payload_bytes, remote_peer, local_port)
}
}
/// Reversible per-byte transform applied to the first `HDP_HEADER_BYTE_LEN` bytes of each
/// packet, keyed by a 128-bit value that the client hands to the server in its first packet
/// (or that both sides derive from a pre-shared key).
///
/// NOTE(review): the transform is a fixed add/XOR with the key bytes, so treat it as
/// obfuscation only, not as confidentiality.
#[derive(Clone)]
pub struct HeaderObfuscator {
    /// Packet a client sends first to deliver its key (or [`DISABLED_KEY`]); `None` on
    /// the server.
    pub first_packet: Option<BytesMut>,
    /// Active key (never zero); `None` until configured or received.
    inner: DualCell<Option<NonZero<u128>>>,
    /// PSK-derived key that an incoming first packet must match, if one was configured.
    expected_key: Option<NonZero<u128>>,
    /// When set, inbound packets are passed through untouched.
    disabled: DualCell<bool>,
    /// Set after the peer's first packet carried [`DISABLED_KEY`]; together with
    /// `disabled` it makes outbound packets pass through untouched.
    client_intends_disable: DualCell<bool>,
}

/// First-packet key value a client sends to ask the server to turn obfuscation off.
const DISABLED_KEY: u128 = u128::MAX;
impl HeaderObfuscator {
pub fn new(is_server: bool, header_obfuscator_settings: HeaderObfuscatorSettings) -> Self {
if is_server {
Self::new_server(header_obfuscator_settings)
} else {
Self::new_client(header_obfuscator_settings)
}
}
pub fn on_packet_received(&self, packet: &mut BytesMut) -> Result<bool, NetworkError> {
if self.is_disabled() {
return Ok(true);
}
if let Some(val) = self.load() {
if packet.len() < HDP_HEADER_BYTE_LEN {
log::warn!(target: "citadel", "[Header Obfuscator] Packet too small: {}", packet.len());
return Ok(false);
}
log::trace!(target: "citadel", "[Header Obfuscator] Applying inbound cipher w/key {val}");
apply_cipher(val, true, packet);
Ok(true)
} else if packet.len() >= 16 {
let key = packet.get_u128();
if key == 0 {
log::error!(target: "citadel", "[Header Obfuscator] Invalid first packet key == 0");
return Err(NetworkError::msg("Invalid first packet key"));
}
if let Some(expected_key) = self.expected_key {
if key != expected_key.get() {
log::error!(target: "citadel", "[Header Obfuscator] Invalid first packet key {key} != {expected_key}");
return Err(NetworkError::msg("Invalid first packet key"));
}
}
if key == DISABLED_KEY {
log::trace!(target: "citadel", "[Header Obfuscator] Disabling obfuscator at client's request");
self.disabled.set(true);
self.client_intends_disable.set(true);
return Ok(false);
}
self.store(key);
log::trace!(target: "citadel", "[Header Obfuscator] initial packet set to {key}");
Ok(false)
} else {
log::warn!(target: "citadel", "[Header Obfuscator] Packet too small (skipping): {}", packet.len());
Ok(false)
}
}
pub fn prepare_outbound(&self, mut packet: BytesMut) -> Bytes {
if self.client_intends_disable.get() && self.disabled.get() {
return packet.freeze();
}
if let Some(key) = self.load() {
if packet.len() >= HDP_HEADER_BYTE_LEN {
log::trace!(target: "citadel", "[Header Obfuscator] Applying outbound cipher w/key {key}");
apply_cipher(key, false, &mut packet);
if self.client_intends_disable.get() {
self.disabled.set(true);
}
}
}
packet.freeze()
}
pub fn new_client(header_obfuscator_settings: HeaderObfuscatorSettings) -> Self {
let key = match header_obfuscator_settings {
HeaderObfuscatorSettings::Enabled => rand::random::<u128>(),
HeaderObfuscatorSettings::Disabled => {
let mut disabled_packet = BytesMut::with_capacity(16);
disabled_packet.put_u128(DISABLED_KEY);
return Self {
inner: None.into(),
first_packet: Some(disabled_packet),
expected_key: None,
disabled: true.into(),
client_intends_disable: false.into(),
};
}
HeaderObfuscatorSettings::EnabledWithKey(key) => key,
};
let key = hash_u128(key);
let mut rng = ThreadRng::default();
let bytes_to_add = rng.gen_range(0..(HDP_HEADER_BYTE_LEN - 17));
let mut packet = vec![0; 16 + bytes_to_add];
let tmp = &mut packet[..];
let mut tmp = tmp.writer();
tmp.write_u128::<BigEndian>(key).expect("Should not fail");
rng.fill_bytes(&mut packet[16..]);
let first_packet = Some(BytesMut::from(&packet[..]));
Self {
inner: DualCell::from(Some(NonZero::new(key).expect("Hashed key cannot be zero"))),
first_packet,
expected_key: None,
disabled: false.into(),
client_intends_disable: false.into(),
}
}
pub fn new_server(header_obfuscator_settings: HeaderObfuscatorSettings) -> Self {
let (inner, expected_key) = match header_obfuscator_settings {
HeaderObfuscatorSettings::Enabled => (DualCell::from(None), None), HeaderObfuscatorSettings::Disabled => (DualCell::from(None), None), HeaderObfuscatorSettings::EnabledWithKey(key) => {
let key = NonZero::new(hash_u128(key)).expect("Hashed key cannot be zero");
(DualCell::from(Some(key)), Some(key))
}
};
Self {
inner,
first_packet: None,
expected_key,
disabled: false.into(), client_intends_disable: false.into(), }
}
fn store(&self, key: u128) {
let key = NonZero::new(key).expect("Input key cannot be zero");
self.inner.set(Some(key));
}
fn load(&self) -> Option<u128> {
Some(self.inner.get()?.get())
}
fn is_disabled(&self) -> bool {
self.disabled.get()
}
}
/// Derives an obfuscation key from `key`.
///
/// The result is the first 16 bytes of SHA3-256 over `key`'s big-endian bytes,
/// read back as a big-endian `u128`.
fn hash_u128(key: u128) -> u128 {
    let mut hasher = sha3::Sha3_256::default();
    hasher.update(key.to_be_bytes());
    let digest = hasher.finalize();
    let mut prefix = [0u8; 16];
    prefix.copy_from_slice(&digest[..16]);
    u128::from_be_bytes(prefix)
}
#[inline]
fn apply_cipher(val: u128, inverse: bool, packet: &mut BytesMut) {
let bytes = val.to_be_bytes();
let (bytes0, bytes1) = bytes.split_at(8);
let packet_len = packet.len().min(HDP_HEADER_BYTE_LEN);
let packet = &mut packet[..packet_len];
bytes0
.iter()
.zip(bytes1.iter())
.cycle()
.zip(packet.iter_mut())
.for_each(|((a, b), c)| cipher_inner(*a, *b, c, inverse))
}
/// Transforms one byte `c` in place using key bytes `a` and `b`.
///
/// Forward: `(c + a) ^ b` with wrapping addition. Inverse: `(c ^ b) - a` with wrapping
/// subtraction. The two directions undo each other.
#[inline]
fn cipher_inner(a: u8, b: u8, c: &mut u8, inverse: bool) {
    let current = *c;
    *c = match inverse {
        false => current.wrapping_add(a) ^ b,
        true => (current ^ b).wrapping_sub(a),
    };
}
/// Owned byte buffer that an [`HdpPacket`] can be built on.
///
/// This file implements it for `BytesMut` (immutable form `Bytes`) and for `Vec<u8>`
/// (immutable form is the `Vec` itself).
pub trait HdpBuffer: AsRef<[u8]> + AsMut<[u8]> + BufMut {
    /// Read-only form produced by [`HdpBuffer::to_immutable`].
    type Immutable;

    /// Number of bytes currently held.
    fn len(&self) -> usize;

    /// Detaches and returns bytes `[0, idx)`, leaving bytes `[idx, len)` in `self`.
    ///
    /// The implementations in this file panic if `idx` exceeds the length.
    fn split_to(&mut self, idx: usize) -> Self;

    /// Converts the buffer into its read-only form.
    fn to_immutable(self) -> Self::Immutable;
}
impl HdpBuffer for BytesMut {
type Immutable = Bytes;
fn len(&self) -> usize {
self.len()
}
fn split_to(&mut self, idx: usize) -> Self {
self.split_to(idx)
}
fn to_immutable(self) -> Self::Immutable {
self.freeze()
}
}
/// `Vec<u8>`-backed buffer; its immutable form is the same `Vec`.
impl HdpBuffer for Vec<u8> {
    type Immutable = Vec<u8>;

    fn len(&self) -> usize {
        // Inherent `Vec::len` takes precedence over this trait method, so no recursion.
        self.len()
    }

    /// Removes and returns the first `idx` bytes (the head), leaving the rest in `self`.
    ///
    /// # Panics
    /// Panics if `idx > self.len()`, like `BytesMut::split_to`.
    fn split_to(&mut self, idx: usize) -> Self {
        self.drain(..idx).collect()
    }

    fn to_immutable(self) -> Self::Immutable {
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;
    use citadel_types::crypto::HeaderObfuscatorSettings;

    /// Freshly built client/server obfuscators start in the expected state.
    #[test]
    fn test_header_obfuscator_client_server_interaction() {
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::Enabled);
        assert!(
            client.first_packet.is_some(),
            "Client should have initial packet"
        );
        assert!(
            client.expected_key.is_none(),
            "Client should not have expected key initially"
        );
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Enabled);
        assert!(
            server.first_packet.is_none(),
            "Server should not have initial packet"
        );
        assert!(
            server.expected_key.is_none(),
            "Server should not have expected key initially"
        );
    }

    /// The client's first packet hands its key to the server; headers then round-trip.
    #[test]
    fn test_header_obfuscator_key_exchange() {
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::Enabled);
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Enabled);
        let mut first_packet = client.first_packet.as_ref().unwrap().clone();
        assert!(server.on_packet_received(&mut first_packet).is_ok());
        assert_eq!(server.load(), client.load());
        assert!(server.load().is_some(), "Both should have non-None key");
        let mut test_packet = BytesMut::with_capacity(HDP_HEADER_BYTE_LEN);
        test_packet.resize(HDP_HEADER_BYTE_LEN, 1);
        let client_processed = client.prepare_outbound(test_packet.clone());
        let mut server_packet = BytesMut::from(&client_processed[..]);
        assert_eq!(
            server.on_packet_received(&mut server_packet).ok(),
            Some(true),
            "Server should de-obfuscate a full-size header"
        );
        assert_eq!(
            server_packet, test_packet,
            "Server should recover the original header"
        );
    }

    /// A disabled client makes the server switch to pass-through mode.
    #[test]
    fn test_header_obfuscator_disabled() {
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::Disabled);
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Disabled);
        assert!(client.load().is_none());
        assert!(server.load().is_none());
        assert!(client.first_packet.is_some());
        assert!(server.first_packet.is_none());

        let mut first_packet = client.first_packet.as_ref().unwrap().clone();
        assert_eq!(
            server.on_packet_received(&mut first_packet).ok(),
            Some(false),
            "The disable request is consumed, not forwarded"
        );
        let mut payload = BytesMut::with_capacity(HDP_HEADER_BYTE_LEN);
        payload.resize(HDP_HEADER_BYTE_LEN, 7);
        assert_eq!(
            &server.prepare_outbound(payload.clone())[..],
            &payload[..],
            "Disabled server should not modify outbound packets"
        );
        let mut inbound = payload.clone();
        assert_eq!(server.on_packet_received(&mut inbound).ok(), Some(true));
        assert_eq!(
            inbound, payload,
            "Disabled server should not modify inbound packets"
        );
    }

    #[test]
    fn test_header_obfuscator_small_packet_ignores() {
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::EnabledWithKey(12345));
        let mut small_packet = BytesMut::with_capacity(16);
        small_packet.resize(15, 1);
        let initial_small_packet = small_packet.clone();
        assert!(
            server.on_packet_received(&mut small_packet).is_ok(),
            "Packets that are smaller than 16 bytes will just be skipped"
        );
        assert_eq!(
            initial_small_packet, small_packet,
            "Packets that are smaller than 16 bytes should not be modified"
        );
        let mut empty_packet = BytesMut::new();
        let initial_empty_packet = empty_packet.clone();
        assert!(
            server.on_packet_received(&mut empty_packet).is_ok(),
            "Empty packets should be skipped"
        );
        assert_eq!(
            initial_empty_packet, empty_packet,
            "Empty packets should not be modified"
        );
    }

    #[test]
    fn test_header_obfuscator_invalid_keys() {
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::EnabledWithKey(12345));
        let mut zero_key_packet = BytesMut::with_capacity(16);
        zero_key_packet.put_u128(0);
        assert_eq!(zero_key_packet.len(), 16);
        assert!(
            server.on_packet_received(&mut zero_key_packet).is_ok(),
            "Should silently ignore packet with zero key"
        );
        let mut invalid_key_packet = BytesMut::with_capacity(16);
        invalid_key_packet.put_u128(54321);
        assert!(
            server.on_packet_received(&mut invalid_key_packet).is_ok(),
            "Should ignore packet with mismatched key"
        );
    }

    #[test]
    fn test_header_obfuscator_invalid_keys_no_preset_server_value() {
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Enabled);
        let mut zero_key_packet = BytesMut::with_capacity(16);
        zero_key_packet.put_u128(0);
        assert!(
            server.on_packet_received(&mut zero_key_packet).is_err(),
            "Should error on packet with zero key"
        );
        assert!(server.load().is_none(), "Server should have no key until the client sends a valid key since the server has no initial key");
        let mut good_first_packet = BytesMut::with_capacity(16);
        good_first_packet.put_u128(12345);
        assert!(
            server.on_packet_received(&mut good_first_packet).is_ok(),
            "Should accept packet with valid key"
        );
        let mut invalid_key_packet = BytesMut::with_capacity(16);
        invalid_key_packet.put_u128(
            server
                .load()
                .expect("Server should have key")
                .wrapping_add(1),
        );
        assert!(
            server.on_packet_received(&mut invalid_key_packet).is_ok(),
            "Should ignore packet with mismatched key"
        );
    }

    #[test]
    fn test_header_obfuscator_disabled_behavior() {
        let disabled_server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Disabled);
        let disabled_client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::Disabled);
        let mut small_packet = BytesMut::with_capacity(8);
        small_packet.resize(8, 1);
        let initial_small = small_packet.clone();
        let mut full_packet = BytesMut::with_capacity(HDP_HEADER_BYTE_LEN);
        full_packet.resize(HDP_HEADER_BYTE_LEN, 2);
        let initial_full = full_packet.clone();
        assert!(disabled_server
            .on_packet_received(&mut small_packet)
            .is_ok());
        assert!(disabled_client.on_packet_received(&mut full_packet).is_ok());
        assert_eq!(
            initial_small, small_packet,
            "Disabled obfuscator should not modify small packets"
        );
        assert_eq!(
            initial_full, full_packet,
            "Disabled obfuscator should not modify full packets"
        );
    }

    #[test]
    fn test_header_obfuscator_key_exchange_flow() {
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::Enabled);
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Enabled);
        let mut first_packet = client.first_packet.as_ref().unwrap().clone();
        assert!(server.on_packet_received(&mut first_packet).is_ok());
        assert_ne!(
            first_packet,
            client.first_packet.as_ref().unwrap().clone(),
            "First packet should be modified by server"
        );
        assert_eq!(server.load(), client.load());
        assert!(server.load().is_some(), "Both should have non-None key");
        let mut test_packet = BytesMut::with_capacity(HDP_HEADER_BYTE_LEN);
        test_packet.resize(HDP_HEADER_BYTE_LEN, 3);
        let initial_test = test_packet.clone();
        let client_processed = client.prepare_outbound(test_packet.clone());
        let mut server_packet = BytesMut::from(&client_processed[..]);
        assert!(server.on_packet_received(&mut server_packet).is_ok());
        assert_eq!(
            server_packet, initial_test,
            "Server should decrypt to original packet"
        );
        let server_processed = server.prepare_outbound(test_packet.clone());
        let mut client_packet = BytesMut::from(&server_processed[..]);
        assert!(client.on_packet_received(&mut client_packet).is_ok());
        assert_eq!(
            client_packet, initial_test,
            "Client should decrypt to original packet"
        );
    }

    /// Both sides derive the same key from a PSK and headers round-trip.
    #[test]
    fn test_header_obfuscator_preshared_key() {
        let psk = 12345u128;
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::EnabledWithKey(psk));
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::EnabledWithKey(psk));
        assert_eq!(client.load(), server.load());
        assert!(client.load().is_some());
        let mut packet = BytesMut::with_capacity(HDP_HEADER_BYTE_LEN);
        packet.resize(HDP_HEADER_BYTE_LEN, 1);
        let client_processed = client.prepare_outbound(packet.clone());
        assert_ne!(
            &client_processed[..],
            &packet[..],
            "Outbound header should be obfuscated"
        );
        let mut server_packet = BytesMut::from(&client_processed[..]);
        assert_eq!(server.on_packet_received(&mut server_packet).ok(), Some(true));
        assert_eq!(
            server_packet, packet,
            "Server should recover the original header with the shared PSK"
        );
    }

    /// Different PSKs produce different ciphertexts and cannot recover each other's headers.
    #[test]
    fn test_header_obfuscator_mismatched_psk() {
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::EnabledWithKey(12345));
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::EnabledWithKey(54321));
        let mut packet = BytesMut::with_capacity(HDP_HEADER_BYTE_LEN);
        packet.resize(HDP_HEADER_BYTE_LEN, 0);
        let client_processed = client.prepare_outbound(packet.clone());
        let server_processed = server.prepare_outbound(packet.clone());
        assert_ne!(client_processed[..], server_processed[..]);
        let mut server_view = BytesMut::from(&client_processed[..]);
        assert_eq!(server.on_packet_received(&mut server_view).ok(), Some(true));
        assert_ne!(
            server_view, packet,
            "A different PSK must not recover the original header"
        );
    }

    #[test]
    fn test_header_obfuscator_key_validation() {
        let mut server =
            HeaderObfuscator::new_server(HeaderObfuscatorSettings::EnabledWithKey(12345));
        let mut invalid_key_packet = BytesMut::with_capacity(16);
        invalid_key_packet.put_u128(54321);
        assert_eq!(
            server.on_packet_received(&mut invalid_key_packet).ok(),
            Some(false),
            "A PSK server already holds its key, so a sub-header-length packet is dropped, not rejected"
        );
        let mut small_valid_key = BytesMut::with_capacity(16);
        small_valid_key.put_u128(12345);
        assert_eq!(
            server.on_packet_received(&mut small_valid_key).ok(),
            Some(false),
            "Even the raw PSK value in a sub-header-length packet is simply dropped as too small"
        );
        assert_eq!(
            server.load(),
            Some(hash_u128(12345)),
            "Dropped packets must not change the PSK-derived key"
        );
        server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Enabled);
        let mut valid_key_packet = BytesMut::with_capacity(16);
        valid_key_packet.put_u128(54321);
        assert!(
            server.on_packet_received(&mut valid_key_packet).is_ok(),
            "Should accept any non-zero key when no PSK"
        );
        assert_eq!(server.load(), Some(54321));
    }

    #[test]
    fn test_header_obfuscator_psk_mismatch_modes() {
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::EnabledWithKey(12345));
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::Enabled);
        let mut first_packet = client.first_packet.as_ref().unwrap().clone();
        assert!(server.on_packet_received(&mut first_packet).is_ok());
        assert!(server.load().is_some());
        assert_eq!(
            server.load(),
            client.load(),
            "A keyless server adopts the PSK-derived key sent by the client"
        );
        let client = HeaderObfuscator::new_client(HeaderObfuscatorSettings::Enabled);
        let server = HeaderObfuscator::new_server(HeaderObfuscatorSettings::EnabledWithKey(12345));
        let mut first_packet = client.first_packet.as_ref().unwrap().clone();
        assert!(server.on_packet_received(&mut first_packet).is_ok());
        assert_eq!(server.load().unwrap(), hash_u128(12345));
    }
}