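//! wiretun 0.5.0 — a WireGuard library.
//!
//! WireGuard message wire formats: fixed-layout handshake, cookie-reply and
//! transport-data packets, and parsers that turn raw datagrams into them.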
use std::fmt::{Debug, Formatter};

// Message type identifiers. Each packet starts with one of these bytes followed
// by three zero bytes (see the `[TYPE, 0, 0, 0]` header checks in the parsers).
const MESSAGE_TYPE_HANDSHAKE_INITIATION: u8 = 1;
const MESSAGE_TYPE_HANDSHAKE_RESPONSE: u8 = 2;
const MESSAGE_TYPE_COOKIE_REPLY: u8 = 3;
const MESSAGE_TYPE_TRANSPORT_DATA: u8 = 4;

/// Exact wire size of a handshake initiation: header (4) + sender index (4)
/// + ephemeral key (32) + static key (32 + 16) + timestamp (12 + 16)
/// + mac1 (16) + mac2 (16).
pub const HANDSHAKE_INITIATION_PACKET_SIZE: usize = 4 + 4 + 32 + (32 + 16) + (12 + 16) + 16 + 16;
/// Exact wire size of a handshake response: header (4) + sender index (4)
/// + receiver index (4) + ephemeral key (32) + empty (16) + mac1 (16) + mac2 (16).
pub const HANDSHAKE_RESPONSE_PACKET_SIZE: usize = 4 + 4 + 4 + 32 + 16 + 16 + 16;
/// Exact wire size of a cookie reply: header (4) + receiver index (4)
/// + nonce (24) + cookie (16 + 16).
pub const COOKIE_REPLY_PACKET_SIZE: usize = 4 + 4 + 24 + (16 + 16);

/// Message-count limit: `u64::MAX - 2^13`, i.e. 2^64 − 2^13 − 1.
pub const REJECT_AFTER_MESSAGES: u64 = u64::MAX - (1u64 << 13);

// Smallest input `Message::parse` will look at: just the 4-byte type/reserved
// header. Per-type length requirements are enforced by each parser.
// TODO: revisit this lower bound.
const MIN_PACKET_SIZE: usize = 4;

use super::Error;

/// A handshake initiation message (message type 1) decoded from the wire.
///
/// Byte layout (148 bytes total): `[1, 0, 0, 0]` header, sender index
/// (little-endian `u32`), ephemeral public key, static public key field
/// (32 + 16 bytes), timestamp field (12 + 16 bytes), `mac1`, `mac2`.
pub struct HandshakeInitiation {
    pub sender_index: u32,
    pub ephemeral_public_key: [u8; 32],
    pub static_public_key: [u8; 32 + 16],
    pub timestamp: [u8; 12 + 16],
    pub mac1: [u8; 16],
    pub mac2: [u8; 16],
}

impl TryFrom<&[u8]> for HandshakeInitiation {
    type Error = Error;

    /// Decodes a handshake initiation from a raw datagram.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidPacket` unless `value` is exactly
    /// `HANDSHAKE_INITIATION_PACKET_SIZE` bytes long and begins with
    /// `[MESSAGE_TYPE_HANDSHAKE_INITIATION, 0, 0, 0]`.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        // Copies the `N` bytes starting at `offset`; `N` is inferred from the
        // destination type. Only called after the length check below.
        fn field<const N: usize>(packet: &[u8], offset: usize) -> [u8; N] {
            packet[offset..offset + N]
                .try_into()
                .expect("packet length validated before field extraction")
        }

        let well_formed = value.len() == HANDSHAKE_INITIATION_PACKET_SIZE
            && value[..4] == [MESSAGE_TYPE_HANDSHAKE_INITIATION, 0, 0, 0];
        if !well_formed {
            return Err(Error::InvalidPacket);
        }

        Ok(Self {
            sender_index: u32::from_le_bytes(field(value, 4)),
            ephemeral_public_key: field(value, 8),
            static_public_key: field(value, 40),
            timestamp: field(value, 88),
            mac1: field(value, 116),
            mac2: field(value, 132),
        })
    }
}

impl Debug for HandshakeInitiation {
    /// Prints only `sender_index`; key fields, timestamp and MACs are left out.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HandshakeInitiation");
        out.field("sender_index", &self.sender_index);
        out.finish()
    }
}

/// A handshake response message (message type 2) decoded from the wire.
///
/// Byte layout (92 bytes total): `[2, 0, 0, 0]` header, sender index and
/// receiver index (both little-endian `u32`), ephemeral public key, a 16-byte
/// `empty` field, `mac1`, `mac2`.
pub struct HandshakeResponse {
    pub sender_index: u32,
    pub receiver_index: u32,
    pub ephemeral_public_key: [u8; 32],
    pub empty: [u8; 16],
    pub mac1: [u8; 16],
    pub mac2: [u8; 16],
}

impl TryFrom<&[u8]> for HandshakeResponse {
    type Error = Error;

    /// Decodes a handshake response from a raw datagram.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidPacket` unless `value` is exactly
    /// `HANDSHAKE_RESPONSE_PACKET_SIZE` bytes long and begins with
    /// `[MESSAGE_TYPE_HANDSHAKE_RESPONSE, 0, 0, 0]`.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        // Copies the `N` bytes starting at `at`; only used once the length is known good.
        fn take<const N: usize>(bytes: &[u8], at: usize) -> [u8; N] {
            bytes[at..at + N]
                .try_into()
                .expect("packet length validated before field extraction")
        }

        let has_size = value.len() == HANDSHAKE_RESPONSE_PACKET_SIZE;
        if !has_size || value[..4] != [MESSAGE_TYPE_HANDSHAKE_RESPONSE, 0, 0, 0] {
            return Err(Error::InvalidPacket);
        }

        let sender_index = u32::from_le_bytes(take(value, 4));
        let receiver_index = u32::from_le_bytes(take(value, 8));
        Ok(Self {
            sender_index,
            receiver_index,
            ephemeral_public_key: take(value, 12),
            empty: take(value, 44),
            mac1: take(value, 60),
            mac2: take(value, 76),
        })
    }
}

impl Debug for HandshakeResponse {
    /// Prints the two session indices; key material and MACs are left out.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HandshakeResponse");
        out.field("sender_index", &self.sender_index)
            .field("receiver_index", &self.receiver_index);
        out.finish()
    }
}

pub struct CookieReply {
    pub receiver_index: u32,
    pub nonce: [u8; 24],
    pub cookie: [u8; 16 + 16],
}

impl TryFrom<&[u8]> for CookieReply {
    type Error = Error;

    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        if value.len() != COOKIE_REPLY_PACKET_SIZE
            || value[0..4] != [MESSAGE_TYPE_COOKIE_REPLY, 0, 0, 0]
        {
            return Err(Error::InvalidPacket);
        }
        Ok(Self {
            receiver_index: u32::from_le_bytes(value[4..8].try_into().unwrap()),
            nonce: value[8..32].try_into().unwrap(),
            cookie: value[32..64].try_into().unwrap(),
        })
    }
}

impl Debug for CookieReply {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CookieReply")
            .field("index", &self.receiver_index)
            .field("nonce", &self.nonce)
            .finish()
    }
}

/// A transport data message (message type 4).
///
/// Byte layout: `[4, 0, 0, 0]` header, receiver index (little-endian `u32`),
/// counter (little-endian `u64`), then `payload` (all remaining bytes).
pub struct TransportData {
    pub receiver_index: u32,
    pub counter: u64,
    pub payload: Vec<u8>,
}

impl TransportData {
    /// Size of the fixed header in front of the payload:
    /// type/reserved (4) + receiver index (4) + counter (8).
    const HEADER_LEN: usize = 16;

    /// Total on-wire length of this message (header plus payload).
    #[inline]
    pub fn packet_len(&self) -> usize {
        Self::HEADER_LEN + self.payload.len()
    }

    /// Serializes the message into its wire representation.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(self.packet_len());
        bytes.extend_from_slice(&[MESSAGE_TYPE_TRANSPORT_DATA, 0, 0, 0]);
        bytes.extend_from_slice(&self.receiver_index.to_le_bytes());
        bytes.extend_from_slice(&self.counter.to_le_bytes());
        bytes.extend_from_slice(&self.payload);
        bytes
    }
}

impl Debug for TransportData {
    /// Prints the receiver index, counter and payload length (not the payload bytes).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TransportData")
            .field("receiver", &self.receiver_index)
            .field("counter", &self.counter)
            .field("len(payload)", &self.payload.len())
            .finish()
    }
}

impl TryFrom<&[u8]> for TransportData {
    type Error = Error;

    /// Decodes a transport data message, copying every byte after the
    /// 16-byte header into `payload` (which may be empty).
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidPacket` if `value` is shorter than the 16-byte
    /// header or does not begin with `[MESSAGE_TYPE_TRANSPORT_DATA, 0, 0, 0]`.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        // The whole header is read below, so the length check must cover all
        // 16 bytes. Checking only the 4-byte type field let 4..=15-byte
        // datagrams reach the slice indexing and panic.
        // NOTE(review): a WireGuard data packet also carries a 16-byte
        // authentication tag, so valid packets are >= 32 bytes. Shorter
        // payloads are still accepted here and left for decryption to reject.
        if value.len() < Self::HEADER_LEN
            || value[..4] != [MESSAGE_TYPE_TRANSPORT_DATA, 0, 0, 0]
        {
            return Err(Error::InvalidPacket);
        }
        let (header, payload) = value.split_at(Self::HEADER_LEN);
        Ok(Self {
            receiver_index: u32::from_le_bytes(
                header[4..8].try_into().expect("header is 16 bytes"),
            ),
            counter: u64::from_le_bytes(header[8..16].try_into().expect("header is 16 bytes")),
            payload: payload.to_vec(),
        })
    }
}

/// Any decoded WireGuard message, tagged by its wire message type.
#[derive(Debug)]
pub enum Message {
    HandshakeInitiation(HandshakeInitiation),
    HandshakeResponse(HandshakeResponse),
    CookieReply(CookieReply),
    TransportData(TransportData),
}

impl Message {
    /// Decodes `payload` into the matching [`Message`] variant, dispatching
    /// on its first (message type) byte.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidPacket` if `payload` is shorter than
    /// `MIN_PACKET_SIZE` or has an unknown type byte. Errors from the
    /// per-type `TryFrom` parser are propagated unchanged.
    pub fn parse(payload: &[u8]) -> Result<Message, Error> {
        if payload.len() < MIN_PACKET_SIZE {
            return Err(Error::InvalidPacket);
        }
        let message = match payload[0] {
            MESSAGE_TYPE_HANDSHAKE_INITIATION => {
                Message::HandshakeInitiation(HandshakeInitiation::try_from(payload)?)
            }
            MESSAGE_TYPE_HANDSHAKE_RESPONSE => {
                Message::HandshakeResponse(HandshakeResponse::try_from(payload)?)
            }
            MESSAGE_TYPE_COOKIE_REPLY => Message::CookieReply(CookieReply::try_from(payload)?),
            MESSAGE_TYPE_TRANSPORT_DATA => {
                Message::TransportData(TransportData::try_from(payload)?)
            }
            _ => return Err(Error::InvalidPacket),
        };

        Ok(message)
    }

    /// Returns `true` if `payload` looks like a handshake initiation or
    /// response, judged by its first byte and its exact length.
    ///
    /// Only the type byte is inspected, not the reserved bytes, so `true`
    /// does not guarantee that [`Message::parse`] will succeed. An empty
    /// slice returns `false`; it previously panicked on `payload[0]`.
    pub fn is_handshake(payload: &[u8]) -> bool {
        match payload.first() {
            Some(&MESSAGE_TYPE_HANDSHAKE_INITIATION) => {
                payload.len() == HANDSHAKE_INITIATION_PACKET_SIZE
            }
            Some(&MESSAGE_TYPE_HANDSHAKE_RESPONSE) => {
                payload.len() == HANDSHAKE_RESPONSE_PACKET_SIZE
            }
            _ => false,
        }
    }
}