use super::handshake::{b2s_hash, b2s_keyed_mac_16, b2s_keyed_mac_16_2, b2s_mac_24};
use crate::noise::handshake::{LABEL_COOKIE, LABEL_MAC1};
use crate::noise::{HandshakeInit, HandshakeResponse, Packet, Tunn, TunnResult, WireGuardError};
#[cfg(feature = "mock-instant")]
use mock_instant::Instant;
use portable_atomic::{AtomicU64, Ordering};
use std::net::IpAddr;
#[cfg(not(feature = "mock-instant"))]
use crate::sleepyinstant::Instant;
use aead::generic_array::GenericArray;
use aead::{AeadInPlace, KeyInit};
use chacha20poly1305::{Key, XChaCha20Poly1305};
use parking_lot::Mutex;
use rand_core::{OsRng, RngCore};
use ring::constant_time::verify_slices_are_equal;
/// Length, in seconds, of one cookie epoch: `current_cookie` divides the seconds elapsed
/// since construction by this value, so the cookie for an address changes every 128 seconds.
const COOKIE_REFRESH: u64 = 128;

/// Size in bytes of a cookie (see [`Cookie`]).
const COOKIE_SIZE: usize = 16;

/// Size in bytes of the nonce produced by `nonce` and written into a cookie reply.
const COOKIE_NONCE_SIZE: usize = 24;

/// Minimum number of whole seconds that must pass before `reset_count` zeroes the counter again.
const RESET_PERIOD: u64 = 1;

/// Raw cookie bytes as produced by `current_cookie`.
type Cookie = [u8; COOKIE_SIZE];
/// State used to verify the MAC fields of handshake packets and to decide when a
/// cookie reply must be sent instead of accepting the handshake.
///
/// `verify_packet` always checks mac1; once `count` has reached `limit` since the
/// last reset it additionally requires a valid mac2 derived from the current cookie.
pub struct RateLimiter {
/// Key passed to `b2s_mac_24` in `nonce` to derive cookie-reply nonces; filled from `OsRng` in `new`.
nonce_key: [u8; 32],
/// Key passed to `b2s_keyed_mac_16_2` in `current_cookie` to derive cookies; filled from `OsRng` in `new`.
secret_key: [u8; 16],
/// Instant captured in `new`; `current_cookie` measures elapsed whole seconds from it.
start_time: Instant,
/// Counter incremented by one for every nonce produced by `nonce`.
nonce_ctr: AtomicU64,
/// `b2s_hash(LABEL_MAC1, public_key)`; key used to recompute mac1 in `verify_packet`.
mac1_key: [u8; 32],
/// `b2s_hash(LABEL_COOKIE, public_key)`; XChaCha20-Poly1305 key used to encrypt the cookie in `format_cookie_reply`.
cookie_key: Key,
/// Threshold for `count`: `is_under_load` reports `true` once the previous count is at or above it.
limit: u64,
/// Number of `is_under_load` calls since the last reset; zeroed by `reset_count`.
count: AtomicU64,
/// Instant at which `reset_count` last zeroed `count` (initially the construction time).
last_reset: Mutex<Instant>,
}
impl RateLimiter {
pub fn new(public_key: &crate::x25519::PublicKey, limit: u64) -> Self {
let mut secret_key = [0u8; 16];
OsRng.fill_bytes(&mut secret_key);
RateLimiter {
nonce_key: Self::rand_bytes(),
secret_key,
start_time: Instant::now(),
nonce_ctr: AtomicU64::new(0),
mac1_key: b2s_hash(LABEL_MAC1, public_key.as_bytes()),
cookie_key: b2s_hash(LABEL_COOKIE, public_key.as_bytes()).into(),
limit,
count: AtomicU64::new(0),
last_reset: Mutex::new(Instant::now()),
}
}
fn rand_bytes() -> [u8; 32] {
let mut key = [0u8; 32];
OsRng.fill_bytes(&mut key);
key
}
pub fn reset_count(&self) {
let current_time = Instant::now();
let mut last_reset_time = self.last_reset.lock();
if current_time.duration_since(*last_reset_time).as_secs() >= RESET_PERIOD {
self.count.store(0, Ordering::SeqCst);
*last_reset_time = current_time;
}
}
fn current_cookie(&self, addr: IpAddr) -> Cookie {
let mut addr_bytes = [0u8; 16];
match addr {
IpAddr::V4(a) => addr_bytes[..4].copy_from_slice(&a.octets()[..]),
IpAddr::V6(a) => addr_bytes[..].copy_from_slice(&a.octets()[..]),
}
let cur_counter = Instant::now().duration_since(self.start_time).as_secs() / COOKIE_REFRESH;
b2s_keyed_mac_16_2(&self.secret_key, &cur_counter.to_le_bytes(), &addr_bytes)
}
fn nonce(&self) -> [u8; COOKIE_NONCE_SIZE] {
let ctr = self.nonce_ctr.fetch_add(1, Ordering::Relaxed);
b2s_mac_24(&self.nonce_key, &ctr.to_le_bytes())
}
fn is_under_load(&self) -> bool {
self.count.fetch_add(1, Ordering::SeqCst) >= self.limit
}
pub(crate) fn format_cookie_reply<'a>(
&self,
idx: u32,
cookie: Cookie,
mac1: &[u8],
dst: &'a mut [u8],
) -> Result<&'a mut [u8], WireGuardError> {
if dst.len() < super::COOKIE_REPLY_SZ {
return Err(WireGuardError::DestinationBufferTooSmall);
}
let (message_type, rest) = dst.split_at_mut(4);
let (receiver_index, rest) = rest.split_at_mut(4);
let (nonce, rest) = rest.split_at_mut(24);
let (encrypted_cookie, _) = rest.split_at_mut(16 + 16);
message_type.copy_from_slice(&super::COOKIE_REPLY.to_le_bytes());
receiver_index.copy_from_slice(&idx.to_le_bytes());
nonce.copy_from_slice(&self.nonce()[..]);
let cipher = XChaCha20Poly1305::new(&self.cookie_key);
let iv = GenericArray::from_slice(nonce);
encrypted_cookie[..16].copy_from_slice(&cookie);
let tag = cipher
.encrypt_in_place_detached(iv, mac1, &mut encrypted_cookie[..16])
.map_err(|_| WireGuardError::DestinationBufferTooSmall)?;
encrypted_cookie[16..].copy_from_slice(&tag);
Ok(&mut dst[..super::COOKIE_REPLY_SZ])
}
pub fn verify_packet<'a, 'b>(
&self,
src_addr: Option<IpAddr>,
src: &'a [u8],
dst: &'b mut [u8],
) -> Result<Packet<'a>, TunnResult<'b>> {
let packet = Tunn::parse_incoming_packet(src)?;
if let Packet::HandshakeInit(HandshakeInit { sender_idx, .. })
| Packet::HandshakeResponse(HandshakeResponse { sender_idx, .. }) = packet
{
let (msg, macs) = src.split_at(src.len() - 32);
let (mac1, mac2) = macs.split_at(16);
let computed_mac1 = b2s_keyed_mac_16(&self.mac1_key, msg);
verify_slices_are_equal(&computed_mac1[..16], mac1)
.map_err(|_| TunnResult::Err(WireGuardError::InvalidMac))?;
if self.is_under_load() {
let addr = match src_addr {
None => return Err(TunnResult::Err(WireGuardError::UnderLoad)),
Some(addr) => addr,
};
let cookie = self.current_cookie(addr);
let computed_mac2 = b2s_keyed_mac_16_2(&cookie, msg, mac1);
if verify_slices_are_equal(&computed_mac2[..16], mac2).is_err() {
let cookie_packet = self
.format_cookie_reply(sender_idx, cookie, mac1, dst)
.map_err(TunnResult::Err)?;
return Err(TunnResult::WriteToNetwork(cookie_packet));
}
}
}
Ok(packet)
}
}