gotatun 0.1.1

an implementation of the WireGuard® protocol designed for portability and speed
Documentation
// Copyright (c) 2019 Cloudflare, Inc. All rights reserved.
//
// Modified by Mullvad VPN.
// Copyright (c) 2025 Mullvad VPN.
//
// SPDX-License-Identifier: BSD-3-Clause

use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24};
use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1};
use crate::noise::{TunnResult, WireGuardError};
use crate::packet::{Packet, WgCookieReply, WgHandshakeBase, WgKind};

use constant_time_eq::constant_time_eq;
#[cfg(feature = "mock_instant")]
use mock_instant::Instant;
use std::net::IpAddr;
use std::sync::atomic::{AtomicU64, Ordering};

#[cfg(not(feature = "mock_instant"))]
use crate::sleepyinstant::Instant;

use aead::generic_array::GenericArray;
use aead::{AeadInPlace, KeyInit};
use chacha20poly1305::{Key, XChaCha20Poly1305};
use parking_lot::Mutex;
use rand_core::{OsRng, RngCore};

/// Number of seconds a derived cookie stays valid: the elapsed time since the limiter was
/// created is divided by this value to pick the current cookie "epoch".
const COOKIE_REFRESH: u64 = 128; // Use 128 and not 120 so the compiler can optimize out the division
/// Length of a cookie in bytes.
const COOKIE_SIZE: usize = 16;
/// Length in bytes of the nonce placed in a cookie reply.
const COOKIE_NONCE_SIZE: usize = 24;

/// Minimum number of whole seconds that must pass between two resets of the packet counter.
const RESET_PERIOD: u64 = 1;

/// A cookie value handed out to (and later expected back from) a handshake sender.
type Cookie = [u8; COOKIE_SIZE];

/// There are two places where WireGuard requires "randomness" for cookies
/// * The 24 byte nonce in the cookie message - here the only goal is to avoid nonce reuse
/// * A secret value that changes every two minutes
///
/// Because the main goal of the cookie is simply for a party to prove ownership of an IP address
/// we can relax the randomness definition a bit, in order to avoid locking, because using less
/// resources is the main goal of any DoS prevention mechanism.
/// In order to avoid locking and calls to rand we derive pseudo random values by running keyed
/// MACs over some counters (a nonce counter and a time-based counter).
pub struct RateLimiter {
    /// The key we use to derive the nonce
    nonce_key: [u8; 32],
    /// The key we use to derive the cookie
    secret_key: [u8; 16],
    /// Creation time of this limiter; the cookie epoch is measured from here
    start_time: Instant,
    /// A single 64 bit counter (should suffice for many years)
    nonce_ctr: AtomicU64,
    /// Key used to check the mac1 field of incoming handshakes (derived from the public key)
    mac1_key: [u8; 32],
    /// Key used to encrypt the cookie inside a cookie reply (derived from the public key)
    cookie_key: Key,
    /// Number of handshakes accepted since the last reset before cookies are demanded
    limit: u64,
    /// The counter since last reset
    count: AtomicU64,
    /// The time last reset was performed on this rate limiter
    last_reset: Mutex<Instant>,
}

impl RateLimiter {
    pub fn new(public_key: &crate::x25519::PublicKey, limit: u64) -> Self {
        let mut secret_key = [0u8; 16];
        OsRng.fill_bytes(&mut secret_key);
        RateLimiter {
            nonce_key: Self::rand_bytes(),
            secret_key,
            start_time: Instant::now(),
            nonce_ctr: AtomicU64::new(0),
            mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()),
            cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(),
            limit,
            count: AtomicU64::new(0),
            last_reset: Mutex::new(Instant::now()),
        }
    }

    /// Return 32 bytes filled by the OS random number generator.
    fn rand_bytes() -> [u8; 32] {
        let mut buf = [0u8; 32];
        OsRng.fill_bytes(&mut buf[..]);
        buf
    }

    /// Zero the packet counter if at least `RESET_PERIOD` whole seconds have passed since the
    /// previous reset; otherwise do nothing.
    ///
    /// Intended to be called periodically (ideally about once per second).
    pub fn reset_count(&self) {
        // Precision is deliberately coarse; at the rates this limiter cares about it is enough.
        let now = Instant::now();
        let mut last_reset = self.last_reset.lock();

        let elapsed_secs = now.duration_since(*last_reset).as_secs();
        if elapsed_secs < RESET_PERIOD {
            return;
        }

        self.count.store(0, Ordering::SeqCst);
        *last_reset = now;
    }

    /// Derive the cookie currently expected from a sender at `addr`.
    ///
    /// The cookie is a keyed MAC, under `secret_key`, of a time-based counter and the
    /// source address. The counter advances once every `COOKIE_REFRESH` seconds, so the
    /// cookie for a given address changes with it.
    fn current_cookie(&self, addr: IpAddr) -> Cookie {
        // Normalise the address to 16 bytes: IPv4 fills the first four, the rest stay zero.
        let addr_bytes: [u8; 16] = match addr {
            IpAddr::V4(v4) => {
                let mut padded = [0u8; 16];
                padded[..4].copy_from_slice(&v4.octets());
                padded
            }
            IpAddr::V6(v6) => v6.octets(),
        };

        // Number of whole refresh periods elapsed since this limiter was created.
        let elapsed_secs = Instant::now().duration_since(self.start_time).as_secs();
        let epoch = elapsed_secs / COOKIE_REFRESH;

        b2s_keyed_mac_16_2(&self.secret_key, &epoch.to_le_bytes(), &addr_bytes)
    }

    /// Produce the next cookie-reply nonce by MACing an incrementing counter with `nonce_key`.
    ///
    /// Each call consumes one counter value, so successive calls hash distinct inputs.
    fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] {
        let counter_bytes = self.nonce_ctr.fetch_add(1, Ordering::Relaxed).to_le_bytes();
        b2s_mac_24(&self.nonce_key, &counter_bytes)
    }

    /// Count one more handshake and report whether the limit had already been reached.
    ///
    /// Returns `true` when the counter value *before* this call was at or above `limit`.
    fn is_under_load(&self) -> bool {
        let seen_before = self.count.fetch_add(1, Ordering::SeqCst);
        seen_before >= self.limit
    }

    /// Build a cookie reply addressed to the sender with index `idx`.
    ///
    /// The `cookie` is encrypted with XChaCha20-Poly1305 under `cookie_key`, using a fresh
    /// nonce from `nonce()` and the sender's `mac1` as associated data. The reply's
    /// `encrypted_cookie` field holds the 16-byte ciphertext followed by the authentication tag.
    pub(crate) fn format_cookie_reply(
        &self,
        idx: u32,
        cookie: Cookie,
        mac1: &[u8],
    ) -> WgCookieReply {
        let aead = XChaCha20Poly1305::new(&self.cookie_key);

        // The message type and reserved bytes are presumably set by `WgCookieReply::new()`
        // (not visible here); only the receiver index, nonce and payload are filled in below.
        let mut reply = WgCookieReply::new();
        // msg.receiver_index = little_endian(initiator.sender_index)
        reply.receiver_idx.set(idx);
        reply.nonce = self.nonce();

        // Place the plaintext cookie in the first 16 bytes and encrypt it in place.
        reply.encrypted_cookie[..16].copy_from_slice(&cookie);
        let nonce = GenericArray::from_slice(&reply.nonce);
        let auth_tag = aead
            .encrypt_in_place_detached(nonce, mac1, &mut reply.encrypted_cookie[..16])
            .expect("wg_cookie_reply is large enough");

        // The detached tag goes right after the ciphertext.
        reply.encrypted_cookie[16..].copy_from_slice(&auth_tag);

        reply
    }

    /// Parse `packet` as a WireGuard message and vet it.
    ///
    /// Handshake initiations and responses are passed through `verify_handshake`
    /// (MAC checks plus rate limiting); every other message kind is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns `TunnResult::Err(WireGuardError::InvalidPacket)` if the packet cannot be
    /// parsed, or whatever `verify_handshake` returns for a rejected handshake.
    pub fn verify_packet(
        &self,
        src_addr: Option<IpAddr>,
        packet: Packet,
    ) -> Result<WgKind, TunnResult> {
        let kind = match packet.try_into_wg() {
            Ok(kind) => kind,
            Err(_) => return Err(TunnResult::Err(WireGuardError::InvalidPacket)),
        };

        // Only handshake messages are verified and rate limited.
        match kind {
            WgKind::HandshakeInit(init) => {
                let verified = self.verify_handshake(src_addr, init)?;
                Ok(WgKind::HandshakeInit(verified))
            }
            WgKind::HandshakeResp(resp) => {
                let verified = self.verify_handshake(src_addr, resp)?;
                Ok(WgKind::HandshakeResp(verified))
            }
            other => Ok(other),
        }
    }

    /// Check the MAC fields of a handshake message and apply rate limiting.
    ///
    /// mac1 is always checked. mac2 is only checked when the limiter is under load, and
    /// checking it requires the sender's address because the cookie is bound to it.
    ///
    /// # Errors
    ///
    /// * `TunnResult::Err(WireGuardError::InvalidMac)` if mac1 does not match.
    /// * `TunnResult::Err(WireGuardError::UnderLoad)` if under load and `src_addr` is `None`.
    /// * `TunnResult::WriteToNetwork(..)` carrying a cookie reply (written over the handshake's
    ///   packet) if under load and mac2 does not match the current cookie.
    pub(crate) fn verify_handshake<P: WgHandshakeBase>(
        &self,
        src_addr: Option<IpAddr>,
        handshake: Packet<P>,
    ) -> Result<Packet<P>, TunnResult> {
        let sender_idx = handshake.sender_idx();
        let mac1 = handshake.mac1();
        let mac2 = handshake.mac2();

        // Everything in the message that precedes the mac1 field.
        let msg_before_macs = &handshake.as_bytes()[..P::MAC1_OFF];

        let expected_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg_before_macs);
        if !constant_time_eq(&expected_mac1, mac1) {
            return Err(TunnResult::Err(WireGuardError::InvalidMac));
        }

        // Below the limit a valid mac1 is all that is required.
        if !self.is_under_load() {
            return Ok(handshake);
        }

        // Only given an address can we validate mac2
        let Some(addr) = src_addr else {
            return Err(TunnResult::Err(WireGuardError::UnderLoad));
        };

        let cookie = self.current_cookie(addr);
        let expected_mac2 = b2s_keyed_mac_16_2(&cookie, msg_before_macs, mac1);
        if constant_time_eq(&expected_mac2, mac2) {
            return Ok(handshake);
        }

        // No valid cookie was presented: answer with a cookie reply instead of processing.
        let cookie_reply = self.format_cookie_reply(sender_idx, cookie, mac1);
        let reply_packet = handshake.overwrite_with(&cookie_reply);
        Err(TunnResult::WriteToNetwork(reply_packet.into()))
    }
}