#![deny(clippy::unwrap_used)]
use std::fmt::{self, Debug};
use std::mem::offset_of;
use std::ops::Deref;
use eyre::{bail, eyre};
use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout, Unaligned, little_endian};
use crate::packet::util::size_must_be;
use crate::packet::{CheckedPayload, Packet};
/// Unsized byte-level view of a buffer that starts with a one-byte packet type.
///
/// Used by `Packet::try_into_wg` to read the type byte before the buffer is
/// cast to a concrete message layout.
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
#[repr(C, packed)]
pub(crate) struct Wg {
/// First byte of the buffer, interpreted as the message type.
pub packet_type: WgPacketType,
/// Every byte following the type byte (unsized tail).
rest: [u8],
}
impl Debug for Wg {
    /// Prints only the `packet_type` field; the trailing `rest` bytes are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Wg");
        out.field("packet_type", &self.packet_type);
        out.finish()
    }
}
/// A packet that has been classified into one of the four known message layouts.
///
/// Produced by `Packet::try_into_wg`; each variant owns the packet typed as the
/// corresponding layout.
pub enum WgKind {
/// Packet typed as [`WgHandshakeInit`].
HandshakeInit(Packet<WgHandshakeInit>),
/// Packet typed as [`WgHandshakeResp`].
HandshakeResp(Packet<WgHandshakeResp>),
/// Packet typed as [`WgCookieReply`].
CookieReply(Packet<WgCookieReply>),
/// Packet typed as [`WgData`].
Data(Packet<WgData>),
}
impl Debug for WgKind {
    /// Prints only the variant name; the wrapped packet contents are not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let variant = match self {
            Self::HandshakeInit(_) => "HandshakeInit",
            Self::HandshakeResp(_) => "HandshakeResp",
            Self::CookieReply(_) => "CookieReply",
            Self::Data(_) => "Data",
        };
        f.debug_tuple(variant).finish()
    }
}
impl From<Packet<WgHandshakeInit>> for WgKind {
fn from(p: Packet<WgHandshakeInit>) -> Self {
WgKind::HandshakeInit(p)
}
}
impl From<Packet<WgHandshakeResp>> for WgKind {
fn from(p: Packet<WgHandshakeResp>) -> Self {
WgKind::HandshakeResp(p)
}
}
impl From<Packet<WgCookieReply>> for WgKind {
fn from(p: Packet<WgCookieReply>) -> Self {
WgKind::CookieReply(p)
}
}
impl From<Packet<WgData>> for WgKind {
fn from(p: Packet<WgData>) -> Self {
WgKind::Data(p)
}
}
impl From<WgKind> for Packet {
fn from(kind: WgKind) -> Self {
match kind {
WgKind::HandshakeInit(packet) => packet.into(),
WgKind::HandshakeResp(packet) => packet.into(),
WgKind::CookieReply(packet) => packet.into(),
WgKind::Data(packet) => packet.into(),
}
}
}
/// One-byte message type tag found at the start of every packet layout in this file.
///
/// This is a transparent wrapper over `u8`, so any byte value is representable;
/// the values this file recognises are exposed as associated constants.
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq, Clone, Copy)]
#[repr(transparent)]
pub struct WgPacketType(pub u8);
impl WgPacketType {
#![allow(non_upper_case_globals)]
pub const HandshakeInit: WgPacketType = WgPacketType(1);
pub const HandshakeResp: WgPacketType = WgPacketType(2);
pub const CookieReply: WgPacketType = WgPacketType(3);
pub const Data: WgPacketType = WgPacketType(4);
}
/// Fixed-size header that precedes the encrypted payload of a data packet.
///
/// Layout, in order: type byte, reserved bytes padding the type to four bytes,
/// a little-endian `u32` receiver index and a little-endian `u64` counter.
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
#[repr(C)]
pub struct WgDataHeader {
/// Message type byte; `WgDataHeader::new` sets it to `WgPacketType::Data`.
packet_type: WgPacketType,
/// Reserved bytes that pad the type field out to four bytes.
_reserved_zeros: [u8; 4 - size_of::<WgPacketType>()],
/// Receiver index, stored little-endian.
pub receiver_idx: little_endian::U32,
/// Packet counter, stored little-endian.
pub counter: little_endian::U64,
}
impl WgDataHeader {
    /// Size of the header in bytes.
    ///
    /// Computed through `size_must_be`, which presumably rejects the build if
    /// the type's size is not 16 (confirm against `packet::util`).
    pub const LEN: usize = size_must_be::<WgDataHeader>(16);
}
/// Unsized layout of a data packet: a [`WgDataHeader`] followed by the
/// encrypted payload and its trailing tag.
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
#[repr(C, packed)]
pub struct WgData {
/// Fixed-size header at the start of the packet.
pub header: WgDataHeader,
/// Everything after the header: encrypted payload bytes followed by the tag.
pub encrypted_encapsulated_packet_and_tag: WgDataAndTag,
}
/// Unsized byte region that is at least [`WgData::TAG_LEN`] bytes long.
///
/// The two fields exist only to encode that minimum length in the type. The
/// accessors on [`WgData`] treat the *last* `TAG_LEN` bytes as the tag and
/// everything before them as the encrypted payload.
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
#[repr(C)]
pub struct WgDataAndTag {
/// Fixed-size part guaranteeing room for one tag.
_tag_size: [u8; WgData::TAG_LEN],
/// Remaining bytes; empty when the region holds nothing but a tag.
_extra: [u8],
}
/// A fixed-size encrypted value immediately followed by a 16-byte tag.
#[derive(Clone, Copy, FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable, PartialEq, Eq)]
#[repr(C)]
pub struct EncryptedWithTag<T: Sized> {
/// The encrypted bytes, laid out as `T`.
pub encrypted: T,
/// Tag stored directly after `encrypted`.
pub tag: [u8; 16],
}
impl WgData {
    /// Bytes a data packet occupies on top of its encrypted payload:
    /// the header plus the trailing tag.
    pub const OVERHEAD: usize = WgDataHeader::LEN + WgData::TAG_LEN;

    /// Length in bytes of the tag that ends every data packet.
    pub const TAG_LEN: usize = 16;

    /// Splits the bytes after the header into `(payload, tag)`, where `tag`
    /// is the last [`WgData::TAG_LEN`] bytes.
    fn split_encapsulated_packet_and_tag(&self) -> (&[u8], &[u8; WgData::TAG_LEN]) {
        self.encrypted_encapsulated_packet_and_tag
            .split_last_chunk::<{ WgData::TAG_LEN }>()
            // Cannot fail: the `WgDataAndTag` layout has a fixed TAG_LEN-byte prefix.
            .expect("WgDataAndTag is at least TAG_LEN bytes long")
    }

    /// Mutable counterpart of [`Self::split_encapsulated_packet_and_tag`].
    fn split_encapsulated_packet_and_tag_mut(&mut self) -> (&mut [u8], &mut [u8; WgData::TAG_LEN]) {
        self.encrypted_encapsulated_packet_and_tag
            .split_last_chunk_mut::<{ WgData::TAG_LEN }>()
            // Cannot fail: the `WgDataAndTag` layout has a fixed TAG_LEN-byte prefix.
            .expect("WgDataAndTag is at least TAG_LEN bytes long")
    }

    /// Returns the encrypted payload, i.e. the bytes between the header and the tag.
    ///
    /// The slice may be empty.
    pub fn encrypted_encapsulated_packet(&self) -> &[u8] {
        let (encrypted_encapsulated_packet, _) = self.split_encapsulated_packet_and_tag();
        encrypted_encapsulated_packet
    }

    /// Returns the encrypted payload mutably.
    ///
    /// The slice may be empty.
    pub fn encrypted_encapsulated_packet_mut(&mut self) -> &mut [u8] {
        let (encrypted_encapsulated_packet, _) = self.split_encapsulated_packet_and_tag_mut();
        encrypted_encapsulated_packet
    }

    /// Returns the trailing tag, always exactly [`WgData::TAG_LEN`] bytes.
    ///
    /// Only reads the packet, so a shared borrow is sufficient.
    pub fn tag(&self) -> &[u8; WgData::TAG_LEN] {
        let (_, tag) = self.split_encapsulated_packet_and_tag();
        tag
    }

    /// Returns the trailing tag mutably.
    pub fn tag_mut(&mut self) -> &mut [u8; WgData::TAG_LEN] {
        let (_, tag) = self.split_encapsulated_packet_and_tag_mut();
        tag
    }

    /// Returns `true` when the packet holds no payload bytes, only header and tag.
    pub const fn is_empty(&self) -> bool {
        self.encrypted_encapsulated_packet_and_tag._extra.is_empty()
    }

    /// Returns `true` for a keepalive, which this code defines as a data
    /// packet with an empty payload.
    pub const fn is_keepalive(&self) -> bool {
        self.is_empty()
    }
}
impl WgDataHeader {
    /// Creates a header with the type set to [`WgPacketType::Data`] and every
    /// other byte zeroed.
    pub fn new() -> Self {
        let mut header = Self::new_zeroed();
        header.packet_type = WgPacketType::Data;
        header
    }

    /// Returns the header with its receiver index replaced by `receiver_idx`.
    pub const fn with_receiver_idx(self, receiver_idx: u32) -> Self {
        Self {
            receiver_idx: little_endian::U32::new(receiver_idx),
            ..self
        }
    }

    /// Returns the header with its counter replaced by `counter`.
    pub const fn with_counter(self, counter: u64) -> Self {
        Self {
            counter: little_endian::U64::new(counter),
            ..self
        }
    }
}
impl Default for WgDataHeader {
fn default() -> Self {
Self::new()
}
}
impl Deref for WgDataAndTag {
type Target = [u8];
fn deref(&self) -> &Self::Target {
self.as_bytes()
}
}
impl std::ops::DerefMut for WgDataAndTag {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut_bytes()
}
}
/// Shared interface of the handshake message layouts that carry a sender
/// index and two 16-byte MAC fields.
pub trait WgHandshakeBase:
    FromBytes + IntoBytes + KnownLayout + Unaligned + Immutable + CheckedPayload
{
    /// Size of the message in bytes.
    const LEN: usize;

    /// Byte offset of the `mac1` field from the start of the message.
    const MAC1_OFF: usize;

    /// Byte offset of the `mac2` field from the start of the message.
    const MAC2_OFF: usize;

    /// Returns the sender index as a native-endian integer.
    fn sender_idx(&self) -> u32;

    /// Mutable access to the first MAC.
    fn mac1_mut(&mut self) -> &mut [u8; 16];

    /// Mutable access to the second MAC.
    fn mac2_mut(&mut self) -> &mut [u8; 16];

    /// Shared access to the first MAC.
    fn mac1(&self) -> &[u8; 16];

    /// Shared access to the second MAC.
    fn mac2(&self) -> &[u8; 16];

    /// Returns the message bytes that precede `mac1`.
    #[inline(always)]
    fn until_mac1(&self) -> &[u8] {
        let bytes = self.as_bytes();
        &bytes[..Self::MAC1_OFF]
    }

    /// Returns the message bytes that precede `mac2` (this includes `mac1`
    /// whenever `MAC1_OFF < MAC2_OFF`).
    #[inline(always)]
    fn until_mac2(&self) -> &[u8] {
        let bytes = self.as_bytes();
        &bytes[..Self::MAC2_OFF]
    }
}
/// Byte layout of a handshake-initiation message (148 bytes, see `WgHandshakeInit::LEN`).
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
#[repr(C, packed)]
pub struct WgHandshakeInit {
/// Message type byte; `WgHandshakeInit::new` sets it to `WgPacketType::HandshakeInit`.
packet_type: WgPacketType,
/// Reserved bytes that pad the type field out to four bytes.
_reserved_zeros: [u8; 4 - size_of::<WgPacketType>()],
/// Sender index, stored little-endian.
pub sender_idx: little_endian::U32,
/// 32-byte ephemeral value carried unencrypted.
pub unencrypted_ephemeral: [u8; 32],
/// 32 encrypted bytes followed by a 16-byte tag.
pub encrypted_static: EncryptedWithTag<[u8; 32]>,
/// 12 encrypted bytes followed by a 16-byte tag.
pub timestamp: EncryptedWithTag<[u8; 12]>,
/// First MAC; its offset is exposed as `WgHandshakeBase::MAC1_OFF`.
pub mac1: [u8; 16],
/// Second MAC; its offset is exposed as `WgHandshakeBase::MAC2_OFF`.
pub mac2: [u8; 16],
}
impl WgHandshakeInit {
    /// Size of the message in bytes.
    ///
    /// Computed through `size_must_be`, which presumably rejects the build if
    /// the type's size is not 148 (confirm against `packet::util`).
    pub const LEN: usize = size_must_be::<WgHandshakeInit>(148);

    /// Creates a message with the type set to [`WgPacketType::HandshakeInit`]
    /// and every other byte zeroed.
    pub fn new() -> Self {
        let mut init = Self::new_zeroed();
        init.packet_type = WgPacketType::HandshakeInit;
        init
    }
}
impl WgHandshakeBase for WgHandshakeInit {
const LEN: usize = Self::LEN;
const MAC1_OFF: usize = offset_of!(Self, mac1);
const MAC2_OFF: usize = offset_of!(Self, mac2);
fn sender_idx(&self) -> u32 {
self.sender_idx.get()
}
fn mac1_mut(&mut self) -> &mut [u8; 16] {
&mut self.mac1
}
fn mac2_mut(&mut self) -> &mut [u8; 16] {
&mut self.mac2
}
fn mac1(&self) -> &[u8; 16] {
&self.mac1
}
fn mac2(&self) -> &[u8; 16] {
&self.mac2
}
}
impl Default for WgHandshakeInit {
fn default() -> Self {
Self::new()
}
}
/// Byte layout of a handshake-response message (92 bytes, see `WgHandshakeResp::LEN`).
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
#[repr(C, packed)]
pub struct WgHandshakeResp {
/// Message type byte; `WgHandshakeResp::new` sets it to `WgPacketType::HandshakeResp`.
packet_type: WgPacketType,
/// Reserved bytes that pad the type field out to four bytes.
_reserved_zeros: [u8; 4 - size_of::<WgPacketType>()],
/// Sender index, stored little-endian.
pub sender_idx: little_endian::U32,
/// Receiver index, stored little-endian.
pub receiver_idx: little_endian::U32,
/// 32-byte ephemeral value carried unencrypted.
pub unencrypted_ephemeral: [u8; 32],
/// Zero encrypted bytes followed by a 16-byte tag.
pub encrypted_nothing: EncryptedWithTag<()>,
/// First MAC; its offset is exposed as `WgHandshakeBase::MAC1_OFF`.
pub mac1: [u8; 16],
/// Second MAC; its offset is exposed as `WgHandshakeBase::MAC2_OFF`.
pub mac2: [u8; 16],
}
impl WgHandshakeResp {
pub const LEN: usize = size_must_be::<Self>(92);
pub fn new(sender_idx: u32, receiver_idx: u32, unencrypted_ephemeral: [u8; 32]) -> Self {
Self {
packet_type: WgPacketType::HandshakeResp,
_reserved_zeros: [0; 3],
sender_idx: sender_idx.into(),
receiver_idx: receiver_idx.into(),
unencrypted_ephemeral,
encrypted_nothing: EncryptedWithTag::new_zeroed(),
mac1: [0u8; 16],
mac2: [0u8; 16],
}
}
}
impl WgHandshakeBase for WgHandshakeResp {
const LEN: usize = Self::LEN;
const MAC1_OFF: usize = offset_of!(Self, mac1);
const MAC2_OFF: usize = offset_of!(Self, mac2);
fn sender_idx(&self) -> u32 {
self.sender_idx.get()
}
fn mac1_mut(&mut self) -> &mut [u8; 16] {
&mut self.mac1
}
fn mac2_mut(&mut self) -> &mut [u8; 16] {
&mut self.mac2
}
fn mac1(&self) -> &[u8; 16] {
&self.mac1
}
fn mac2(&self) -> &[u8; 16] {
&self.mac2
}
}
/// Byte layout of a cookie-reply message (64 bytes, see `WgCookieReply::LEN`).
#[derive(FromBytes, IntoBytes, KnownLayout, Unaligned, Immutable)]
#[repr(C, packed)]
pub struct WgCookieReply {
/// Message type byte; `WgCookieReply::new` sets it to `WgPacketType::CookieReply`.
packet_type: WgPacketType,
/// Reserved bytes that pad the type field out to four bytes.
_reserved_zeros: [u8; 4 - size_of::<WgPacketType>()],
/// Receiver index, stored little-endian.
pub receiver_idx: little_endian::U32,
/// 24-byte nonce.
pub nonce: [u8; 24],
/// 16 encrypted bytes followed by a 16-byte tag.
pub encrypted_cookie: EncryptedWithTag<[u8; 16]>,
}
impl WgCookieReply {
    /// Size of the message in bytes.
    ///
    /// Computed through `size_must_be`, which presumably rejects the build if
    /// the type's size is not 64 (confirm against `packet::util`).
    pub const LEN: usize = size_must_be::<WgCookieReply>(64);

    /// Creates a message with the type set to [`WgPacketType::CookieReply`]
    /// and every other byte zeroed.
    pub fn new() -> Self {
        let mut reply = Self::new_zeroed();
        reply.packet_type = WgPacketType::CookieReply;
        reply
    }
}
impl Default for WgCookieReply {
fn default() -> Self {
Self::new()
}
}
impl Packet {
    /// Classifies the packet by its leading type byte and total length.
    ///
    /// Handshake initiation, handshake response and cookie reply must match
    /// their layout's `LEN` exactly; a data packet must be at least
    /// [`WgData::OVERHEAD`] bytes long.
    ///
    /// # Errors
    ///
    /// Returns an error when the buffer cannot be viewed as a [`Wg`] (reported
    /// as "too small"), or when the type/length combination matches none of
    /// the known layouts.
    pub fn try_into_wg(self) -> eyre::Result<WgKind> {
        // Copy out what is needed so the borrow of `self` ends before `cast` consumes it.
        let (packet_type, len) = {
            let wg = Wg::ref_from_bytes(self.as_bytes())
                .map_err(|_| eyre!("Not a wireguard packet, too small."))?;
            (wg.packet_type, wg.as_bytes().len())
        };

        let kind = match packet_type {
            WgPacketType::HandshakeInit if len == WgHandshakeInit::LEN => {
                WgKind::HandshakeInit(self.cast())
            }
            WgPacketType::HandshakeResp if len == WgHandshakeResp::LEN => {
                WgKind::HandshakeResp(self.cast())
            }
            WgPacketType::CookieReply if len == WgCookieReply::LEN => {
                WgKind::CookieReply(self.cast())
            }
            WgPacketType::Data if len >= WgData::OVERHEAD => WgKind::Data(self.cast()),
            _ => bail!("Not a wireguard packet, bad type/size."),
        };
        Ok(kind)
    }
}
impl Debug for WgPacketType {
    /// Prints the constant's name for known types and the raw byte value otherwise.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match *self {
            WgPacketType::HandshakeInit => "HandshakeInit",
            WgPacketType::HandshakeResp => "HandshakeResp",
            WgPacketType::CookieReply => "CookieReply",
            WgPacketType::Data => "Data",
            // Unknown type: fall back to the integer's own `Debug` output.
            WgPacketType(raw) => return Debug::fmt(&raw, f),
        };
        f.debug_tuple(label).finish()
    }
}