pub mod errors;
pub mod handshake;
pub mod index_table;
pub mod rate_limiter;
mod session;
mod timers;
use rand::{Rng, RngCore, SeedableRng, rngs::StdRng};
use zerocopy::IntoBytes;
use crate::noise::errors::WireGuardError;
use crate::noise::handshake::Handshake;
use crate::noise::index_table::IndexTable;
use crate::noise::rate_limiter::RateLimiter;
use crate::noise::timers::{MAX_JITTER, TimerName, Timers};
use crate::packet::{Packet, WgCookieReply, WgData, WgHandshakeInit, WgHandshakeResp, WgKind};
use crate::tun::MtuWatcher;
use crate::x25519;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
/// Maximum number of outgoing packets (256) held while waiting for a session;
/// `queue_packet` silently drops anything beyond this.
const MAX_QUEUE_DEPTH: usize = 1 << 8;
/// Number of session slots (8) in the ring kept by each `Tunn`.
const N_SESSIONS: usize = 1 << 3;
/// Outcome of handing a packet to a `Tunn`.
#[derive(Debug)]
pub enum TunnResult {
    /// Processing finished and no packet needs to be written anywhere.
    Done,
    /// Processing failed with the contained error.
    Err(WireGuardError),
    /// The contained WireGuard message should be sent to the peer.
    WriteToNetwork(WgKind),
    /// The contained decrypted packet should be written to the tunnel device.
    WriteToTunnel(Packet),
}

impl From<WireGuardError> for TunnResult {
    /// Wrap an error as `TunnResult::Err`, letting fallible handlers be
    /// collapsed into a single `TunnResult`.
    fn from(err: WireGuardError) -> Self {
        Self::Err(err)
    }
}
pub struct Tunn<R: RngCore + Send = StdRng> {
handshake: handshake::Handshake,
sessions: [Option<session::Session>; N_SESSIONS],
current: usize,
session_counter: usize,
packet_queue: VecDeque<Packet>,
timers: timers::Timers,
tx_bytes: usize,
rx_bytes: usize,
rate_limiter: Arc<RateLimiter>,
jitter_rng: R,
}
impl Tunn<StdRng> {
    /// Create a tunnel whose handshake jitter is drawn from an OS-seeded `StdRng`.
    ///
    /// Identical to calling `Tunn::new_with_rng` with `StdRng::from_os_rng()`.
    pub fn new(
        static_private: x25519::StaticSecret,
        peer_static_public: x25519::PublicKey,
        preshared_key: Option<[u8; 32]>,
        persistent_keepalive: Option<u16>,
        index_table: IndexTable,
        rate_limiter: Arc<RateLimiter>,
    ) -> Self {
        let jitter_rng = StdRng::from_os_rng();
        Self::new_with_rng(
            static_private,
            peer_static_public,
            preshared_key,
            persistent_keepalive,
            index_table,
            rate_limiter,
            jitter_rng,
        )
    }
}
impl<R: RngCore + Send> Tunn<R> {
pub fn new_with_rng(
static_private: x25519::StaticSecret,
peer_static_public: x25519::PublicKey,
preshared_key: Option<[u8; 32]>,
persistent_keepalive: Option<u16>,
index_table: IndexTable,
rate_limiter: Arc<RateLimiter>,
jitter_rng: R,
) -> Self {
let static_public = x25519::PublicKey::from(&static_private);
Tunn {
handshake: Handshake::new(
static_private,
static_public,
peer_static_public,
index_table,
preshared_key,
),
sessions: Default::default(),
current: Default::default(),
session_counter: Default::default(),
tx_bytes: Default::default(),
rx_bytes: Default::default(),
packet_queue: VecDeque::new(),
timers: Timers::new(persistent_keepalive),
rate_limiter,
jitter_rng,
}
}
pub fn is_expired(&self) -> bool {
self.handshake.is_expired()
}
pub fn set_static_private(
&mut self,
static_private: x25519::StaticSecret,
static_public: x25519::PublicKey,
rate_limiter: Arc<RateLimiter>,
) {
self.rate_limiter = rate_limiter;
self.handshake
.set_static_private(static_private, static_public);
for s in &mut self.sessions {
*s = None;
}
}
pub fn handle_outgoing_packet(
&mut self,
mut packet: Packet,
tun_mtu: Option<&mut MtuWatcher>,
) -> Option<WgKind> {
if let Some(tun_mtu) = tun_mtu {
packet = pad_to_x16(packet, tun_mtu);
}
match self.encapsulate_with_session(packet) {
Ok(encapsulated_packet) => Some(encapsulated_packet.into()),
Err(packet) => {
self.queue_packet(packet);
self.format_handshake_initiation(false).map(Into::into)
}
}
}
pub fn encapsulate_with_session(&mut self, packet: Packet) -> Result<Packet<WgData>, Packet> {
let current = self.current;
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
let packet = session.format_packet_data(packet);
self.timer_tick(TimerName::TimeLastPacketSent);
if !packet.is_keepalive() {
self.timer_tick(TimerName::TimeLastDataPacketSent);
}
self.tx_bytes += packet.as_bytes().len();
Ok(packet)
} else {
Err(packet)
}
}
pub fn handle_incoming_packet(&mut self, packet: WgKind) -> TunnResult {
match packet {
WgKind::HandshakeInit(p) => self.handle_handshake_init(p),
WgKind::HandshakeResp(p) => self.handle_handshake_response(p),
WgKind::CookieReply(p) => self.handle_cookie_reply(&p),
WgKind::Data(p) => self.handle_data(p),
}
.unwrap_or_else(TunnResult::from)
}
fn handle_handshake_init(
&mut self,
p: Packet<WgHandshakeInit>,
) -> Result<TunnResult, WireGuardError> {
log::debug!("Received handshake_initiation: {}", p.sender_idx);
let (packet, session) = self.handshake.receive_handshake_initialization(p)?;
let slot = self.next_session_slot();
self.put_session(slot, session);
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeLastPacketSent);
self.timer_tick_session_established(false, slot);
log::debug!("Sending handshake_response: {slot}");
Ok(TunnResult::WriteToNetwork(packet.into()))
}
fn handle_handshake_response(
&mut self,
p: Packet<WgHandshakeResp>,
) -> Result<TunnResult, WireGuardError> {
log::debug!(
"Received handshake_response: {} {}",
p.receiver_idx,
p.sender_idx,
);
let session = self.handshake.receive_handshake_response(&p)?;
let mut p = p.into_bytes();
p.truncate(0);
let keepalive_packet = session.format_packet_data(p);
let slot = self.next_session_slot();
self.put_session(slot, session);
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick_session_established(true, slot); self.set_current_session(slot);
log::debug!("Sending keepalive");
Ok(TunnResult::WriteToNetwork(keepalive_packet.into())) }
fn handle_cookie_reply(&mut self, p: &WgCookieReply) -> Result<TunnResult, WireGuardError> {
log::debug!("Received cookie_reply: {}", p.receiver_idx);
self.handshake.receive_cookie_reply(p)?;
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeCookieReceived);
Ok(TunnResult::Done)
}
fn set_current_session(&mut self, new_slot: usize) {
let cur_slot = self.current;
if cur_slot == new_slot {
return;
}
if self.sessions[cur_slot % N_SESSIONS].is_none()
|| self.timers.session_timers[new_slot % N_SESSIONS]
>= self.timers.session_timers[cur_slot % N_SESSIONS]
{
self.current = new_slot;
log::trace!("New session slot: {new_slot}");
}
}
fn next_session_slot(&mut self) -> usize {
let slot = self.session_counter % N_SESSIONS;
self.session_counter = self.session_counter.wrapping_add(1);
slot
}
fn put_session(&mut self, slot: usize, session: session::Session) {
self.sessions[slot % N_SESSIONS] = Some(session);
}
fn handle_data(&mut self, packet: Packet<WgData>) -> Result<TunnResult, WireGuardError> {
let decapsulated_packet = self.decapsulate_with_session(packet)?;
self.timer_tick(TimerName::TimeLastDataPacketReceived);
self.rx_bytes += decapsulated_packet.as_bytes().len();
Ok(TunnResult::WriteToTunnel(decapsulated_packet))
}
pub fn decapsulate_with_session(
&mut self,
packet: Packet<WgData>,
) -> Result<Packet, WireGuardError> {
let r_idx = packet.header.receiver_idx.get();
let (slot, session) = self
.sessions
.iter()
.enumerate()
.filter_map(|(i, s)| s.as_ref().map(|s| (i, s)))
.find(|(_, s)| s.receiving_index.value() == r_idx)
.ok_or_else(|| {
log::trace!("No session available: {r_idx}");
WireGuardError::NoCurrentSession
})?;
let decapsulated_packet = session.receive_packet_data(packet)?;
self.set_current_session(slot);
self.timer_tick(TimerName::TimeLastPacketReceived);
Ok(decapsulated_packet)
}
pub fn format_handshake_initiation(
&mut self,
force_resend: bool,
) -> Option<Packet<WgHandshakeInit>> {
if self.handshake.is_in_progress() && !force_resend {
return None;
}
if self.handshake.is_expired() {
self.timers.clear();
}
let starting_new_handshake = !self.handshake.is_in_progress();
let packet = self.handshake.format_handshake_initiation();
log::debug!("Sending handshake_initiation");
if starting_new_handshake {
self.timer_tick(TimerName::TimeLastHandshakeStarted);
}
self.timer_tick(TimerName::TimeLastPacketSent);
self.update_handshake_jitter();
Some(packet)
}
fn update_handshake_jitter(&mut self) {
self.timers.handshake_jitter = self.next_jitter();
}
fn next_jitter(&mut self) -> Duration {
self.jitter_rng.random_range(Duration::ZERO..=MAX_JITTER)
}
pub fn get_queued_packets(&mut self, tun_mtu: &mut MtuWatcher) -> impl Iterator<Item = WgKind> {
std::iter::from_fn(|| {
self.dequeue_packet()
.and_then(|packet| self.handle_outgoing_packet(packet, Some(tun_mtu)))
})
}
fn queue_packet(&mut self, packet: Packet) {
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
self.packet_queue.push_back(packet);
}
}
fn dequeue_packet(&mut self) -> Option<Packet> {
self.packet_queue.pop_front()
}
fn estimate_loss(&self) -> f32 {
let session_idx = self.current;
let mut weight = 9.0;
let mut cur_avg = 0.0;
let mut total_weight = 0.0;
for i in 0..N_SESSIONS {
if let Some(ref session) = self.sessions[session_idx.wrapping_sub(i) % N_SESSIONS] {
let (expected, received) = session.current_packet_cnt();
let loss = if expected == 0 {
0.0
} else {
1.0 - received as f32 / expected as f32
};
cur_avg += loss * weight;
total_weight += weight;
weight /= 3.0;
}
}
if total_weight == 0.0 {
0.0
} else {
cur_avg / total_weight
}
}
pub fn stats(&self) -> (Option<Duration>, usize, usize, f32, Option<u32>) {
let time = self.time_since_last_handshake();
let tx_bytes = self.tx_bytes;
let rx_bytes = self.rx_bytes;
let loss = self.estimate_loss();
let rtt = self.handshake.last_rtt;
(time, tx_bytes, rx_bytes, loss, rtt)
}
}
/// Zero-pad `packet` so its length becomes a multiple of 16, never growing it
/// past the MTU reported by `tun_mtu`.
///
/// Packets already at a multiple of 16 are returned untouched. A packet longer
/// than the MTU is never shrunk. It is left as is, with a debug log in debug
/// builds.
fn pad_to_x16(mut packet: Packet, tun_mtu: &mut MtuWatcher) -> Packet {
    let len = packet.len();
    if len.is_multiple_of(16) {
        return packet;
    }
    let mtu = usize::from(tun_mtu.get());
    if cfg!(debug_assertions) && len > mtu {
        log::debug!("Packet length exceeded MTU: {} > {mtu}", len);
    }
    // Round up to the next 16-byte boundary, clamp to the MTU, and never go
    // below the current length.
    let target = len.next_multiple_of(16).min(mtu).max(len);
    debug_assert!(target >= len);
    packet.buf_mut().resize(target, 0);
    packet
}
// Unit tests that drive two in-memory tunnels against each other.
// Tests that manipulate time are gated on the "mock_instant" feature and use MockClock.
#[cfg(test)]
mod tests {
use std::net::Ipv4Addr;
#[cfg(feature = "mock_instant")]
use crate::noise::timers::{MAX_JITTER, REKEY_AFTER_TIME, REKEY_TIMEOUT, TimerName};
use crate::packet::Ipv4;
// Limit passed to every RateLimiter built by these tests.
const HANDSHAKE_RATE_LIMIT: u64 = 100;
use super::*;
use bytes::BytesMut;
#[cfg(feature = "mock_instant")]
use mock_instant::thread_local::MockClock;
// Build two tunnels configured as each other's peer. Each gets fresh random keys
// and its own rate limiter keyed by its own public key. Neither has a preshared
// key or persistent keepalive. No messages are exchanged yet.
fn create_two_tuns() -> (Tunn, Tunn) {
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(rand_core::OsRng);
let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key);
let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(rand_core::OsRng);
let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key);
let rate_limiter = Arc::new(RateLimiter::new(&my_public_key, HANDSHAKE_RATE_LIMIT));
let my_tun = Tunn::new(
my_secret_key,
their_public_key,
None,
None,
IndexTable::from_os_rng(),
rate_limiter,
);
let rate_limiter = Arc::new(RateLimiter::new(&their_public_key, HANDSHAKE_RATE_LIMIT));
let their_tun = Tunn::new(
their_secret_key,
my_public_key,
None,
None,
IndexTable::from_os_rng(),
rate_limiter,
);
(my_tun, their_tun)
}
// Start a (non-forced) handshake on `tun`, panicking if none is produced.
fn create_handshake_init(tun: &mut Tunn) -> Packet<WgHandshakeInit> {
tun.format_handshake_initiation(false)
.expect("expected handshake init")
}
// Feed `handshake_init` to `tun` and return the handshake response it emits.
fn create_handshake_response(
tun: &mut Tunn,
handshake_init: Packet<WgHandshakeInit>,
) -> Packet<WgHandshakeResp> {
let handshake_resp = tun.handle_incoming_packet(WgKind::HandshakeInit(handshake_init));
assert!(
matches!(handshake_resp, TunnResult::WriteToNetwork(_)),
"expected WriteToNetwork, {handshake_resp:?}"
);
let TunnResult::WriteToNetwork(handshake_resp) = handshake_resp else {
unreachable!("expected WriteToNetwork");
};
let WgKind::HandshakeResp(handshake_resp) = handshake_resp else {
unreachable!("expected WgHandshakeResp, got {handshake_resp:?}");
};
handshake_resp
}
// Feed a handshake response to the initiator and return the keepalive data
// message it sends back.
fn parse_handshake_resp(
tun: &mut Tunn,
handshake_resp: Packet<WgHandshakeResp>,
) -> Packet<WgData> {
let keepalive = tun.handle_incoming_packet(WgKind::HandshakeResp(handshake_resp));
assert!(matches!(keepalive, TunnResult::WriteToNetwork(_)));
let TunnResult::WriteToNetwork(keepalive) = keepalive else {
unreachable!("expected WriteToNetwork")
};
let WgKind::Data(keepalive) = keepalive else {
unreachable!("expected WgData, got {keepalive:?}");
};
keepalive
}
// A received keepalive decrypts to an empty packet destined for the tunnel.
fn parse_keepalive(tun: &mut Tunn, keepalive: Packet<WgData>) {
let result = tun.handle_incoming_packet(WgKind::Data(keepalive));
assert!(matches!(result, TunnResult::WriteToTunnel(p) if p.is_empty()));
}
// Two tunnels that have completed init -> response -> keepalive, so both hold
// a session.
fn create_two_tuns_and_handshake() -> (Tunn, Tunn) {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init);
let keepalive = parse_handshake_resp(&mut my_tun, resp);
parse_keepalive(&mut their_tun, keepalive);
(my_tun, their_tun)
}
// IPv4/UDP packet 192.168.1.2:5678 -> 192.168.1.3:23, TTL 5, payload [0, 1, 2, 3].
fn create_ipv4_udp_packet() -> Packet<Ipv4> {
let header =
etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23);
let payload = [0, 1, 2, 3];
let mut packet = Vec::<u8>::with_capacity(header.size(payload.len()));
header.write(&mut packet, &payload).unwrap();
let packet = Packet::from_bytes(BytesMut::from(&packet[..]));
packet.try_into_ipvx().unwrap().unwrap_left()
}
// Assert that the next timer update makes `tun` emit a handshake initiation.
#[cfg(feature = "mock_instant")]
fn update_timer_results_in_handshake(tun: &mut Tunn) {
let packet = tun
.update_timers()
.expect("update_timers should succeed")
.unwrap();
assert!(matches!(packet, WgKind::HandshakeInit(..)));
}
#[test]
fn create_two_tunnels_linked_to_eachother() {
let (_my_tun, _their_tun) = create_two_tuns();
}
#[test]
fn handshake_init() {
let (mut my_tun, _their_tun) = create_two_tuns();
let _init = create_handshake_init(&mut my_tun);
}
// Each side's rate limiter accepts the handshake message the other side produced.
#[test]
fn verify_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init.clone());
their_tun
.rate_limiter
.verify_handshake(Ipv4Addr::LOCALHOST.into(), init)
.expect("Handshake init to be valid");
my_tun
.rate_limiter
.verify_handshake(Ipv4Addr::LOCALHOST.into(), resp)
.expect("Handshake response to be valid");
}
// Once the limit is exceeded, the responder answers with a cookie reply; after
// the initiator processes it, the next initiation is accepted.
#[test]
#[cfg(feature = "mock_instant")]
fn verify_cookie_reply() {
let forced_handshake_init = |tun: &mut Tunn| {
tun.format_handshake_initiation(true)
.expect("expected handshake init")
};
let (mut my_tun, their_tun) = create_two_tuns();
// Use up the allowance; the mock clock moves 1µs per attempt.
for _ in 0..HANDSHAKE_RATE_LIMIT {
let init = forced_handshake_init(&mut my_tun);
their_tun
.rate_limiter
.verify_handshake(Ipv4Addr::LOCALHOST.into(), init)
.expect("Handshake init to be valid");
MockClock::advance(Duration::from_micros(1));
}
let init = forced_handshake_init(&mut my_tun);
let Err(TunnResult::WriteToNetwork(WgKind::CookieReply(cookie_resp))) = their_tun
.rate_limiter
.verify_handshake(Ipv4Addr::LOCALHOST.into(), init)
else {
panic!("expected cookie reply due to rate limiting");
};
my_tun
.handle_cookie_reply(&cookie_resp)
.expect("expected cookie reply to be valid");
let init = forced_handshake_init(&mut my_tun);
their_tun
.rate_limiter
.verify_handshake(Ipv4Addr::LOCALHOST.into(), init)
.expect("should accept handshake with cookie");
}
// Swapping mac1 between the init and the response must make both fail verification.
#[test]
fn reject_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let mut init = create_handshake_init(&mut my_tun);
let mut resp = create_handshake_response(&mut their_tun, init.clone());
std::mem::swap(&mut init.mac1, &mut resp.mac1);
their_tun
.rate_limiter
.verify_handshake(Ipv4Addr::LOCALHOST.into(), init.clone())
.map(|packet| packet.mac1)
.expect_err("Handshake init to be invalid");
my_tun
.rate_limiter
.verify_handshake(Ipv4Addr::LOCALHOST.into(), resp)
.map(|packet| packet.mac1)
.expect_err("Handshake response to be invalid");
}
#[test]
fn handshake_init_and_response() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let _resp = create_handshake_response(&mut their_tun, init);
}
#[test]
fn full_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init);
let _keepalive = parse_handshake_resp(&mut my_tun, resp);
}
// Right after a completed handshake, neither side's timers produce anything.
#[test]
fn full_handshake_plus_timers() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
assert!(matches!(my_tun.update_timers(), Ok(None)));
assert!(matches!(their_tun.update_timers(), Ok(None)));
}
// After sending data and then waiting REKEY_AFTER_TIME, the sender starts a new
// handshake while the receiver stays quiet.
#[test]
#[cfg(feature = "mock_instant")]
fn new_handshake_after_two_mins() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
MockClock::advance(Duration::from_secs(1));
assert!(matches!(their_tun.update_timers(), Ok(None)));
assert!(matches!(my_tun.update_timers(), Ok(None)));
let sent_packet_buf = create_ipv4_udp_packet();
let _data = my_tun
.handle_outgoing_packet(sent_packet_buf.into_bytes(), None)
.expect("expected encapsulated packet");
MockClock::advance(REKEY_AFTER_TIME);
assert!(matches!(their_tun.update_timers(), Ok(None)));
update_timer_results_in_handshake(&mut my_tun);
}
// With no response, the initiation is retried once REKEY_TIMEOUT plus the
// maximum jitter has elapsed.
#[test]
#[cfg(feature = "mock_instant")]
fn handshake_no_resp_rekey_timeout() {
let (mut my_tun, _their_tun) = create_two_tuns();
let _init = create_handshake_init(&mut my_tun);
MockClock::advance(REKEY_TIMEOUT + MAX_JITTER + Duration::from_millis(1));
update_timer_results_in_handshake(&mut my_tun)
}
// An IP packet survives an encrypt/decrypt round trip byte-for-byte.
#[test]
fn one_ip_packet() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
let sent_packet_buf = create_ipv4_udp_packet();
let data = my_tun
.handle_outgoing_packet(sent_packet_buf.clone().into_bytes(), None)
.unwrap();
assert!(matches!(data, WgKind::Data(..)));
let data = their_tun.handle_incoming_packet(data);
let recv_packet_buf = if let TunnResult::WriteToTunnel(recv) = data {
recv
} else {
unreachable!("expected WritetoTunnelV4");
};
assert_eq!(sent_packet_buf.as_bytes(), recv_packet_buf.as_bytes());
}
// TimeCurrent must not move backwards when the clock does.
#[test]
#[cfg(feature = "mock_instant")]
fn update_timers_handles_backward_time_jump() {
const PRESENT: Duration = Duration::from_secs(10);
const PAST: Duration = Duration::from_secs(5);
MockClock::set_time(Duration::ZERO);
let (mut my_tun, mut _their_tun) = create_two_tuns_and_handshake();
MockClock::advance(PRESENT);
my_tun.update_timers().unwrap();
let time_current_before = my_tun.timers[TimerName::TimeCurrent];
assert_eq!(time_current_before, PRESENT);
MockClock::set_time(PAST);
my_tun.update_timers().unwrap();
let time_current_after = my_tun.timers[TimerName::TimeCurrent];
assert_eq!(
time_current_after, PRESENT,
"TimeCurrent should never decrease"
);
}
// The reported time since the last handshake must not shrink on a backward jump.
#[test]
#[cfg(feature = "mock_instant")]
fn time_since_last_handshake_doesnt_decrease_on_backward_jump() {
const PRESENT: Duration = Duration::from_secs(60);
MockClock::set_time(Duration::ZERO);
let (mut my_tun, mut _their_tun) = create_two_tuns_and_handshake();
MockClock::advance(PRESENT);
my_tun.update_timers().unwrap();
let time_since = my_tun.time_since_last_handshake().expect("have handshake");
assert!(time_since >= PRESENT);
assert!(time_since > Duration::ZERO);
MockClock::set_time(Duration::ZERO);
my_tun.update_timers().unwrap();
let time_since_after_jump = my_tun.time_since_last_handshake();
assert_eq!(
time_since_after_jump,
Some(PRESENT),
"time_since_last_handshake should never decrease"
);
}
// The handshake retry fires exactly at REKEY_TIMEOUT + jitter.
// FixedRng always returns the same word, so `next_jitter` yields the same
// duration here as when the handshake below is started.
#[test]
#[cfg(feature = "mock_instant")]
fn handshake_jitter_applied() {
struct FixedRng(u32);
impl rand::RngCore for FixedRng {
fn next_u32(&mut self) -> u32 {
self.0
}
fn next_u64(&mut self) -> u64 {
u64::from(self.0)
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
dest.fill(0);
}
}
MockClock::set_time(Duration::ZERO);
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(rand_core::OsRng);
let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key);
let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(rand_core::OsRng);
let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key);
let rate_limiter = Arc::new(RateLimiter::new(&my_public_key, HANDSHAKE_RATE_LIMIT));
let mut my_tun = Tunn::new_with_rng(
my_secret_key,
their_public_key,
None,
None,
IndexTable::from_os_rng(),
rate_limiter,
FixedRng(200),
);
let expected_jitter = my_tun.next_jitter();
let packet = create_ipv4_udp_packet();
// No session exists yet, so this queues the packet and starts a handshake.
let _ = my_tun.handle_outgoing_packet(packet.into_bytes(), None);
MockClock::advance(REKEY_TIMEOUT + expected_jitter - Duration::from_millis(1));
assert!(
matches!(my_tun.update_timers(), Ok(None)),
"retry should not fire before REKEY_TIMEOUT + jitter"
);
MockClock::advance(Duration::from_millis(1));
assert!(
matches!(my_tun.update_timers(), Ok(Some(WgKind::HandshakeInit(..)))),
"retry should fire at REKEY_TIMEOUT + jitter"
);
}
// Rate limiting one source IP must not affect handshakes from a different IP.
#[test]
#[cfg(feature = "mock_instant")]
fn per_ip_rate_limiting_isolation() {
let (mut my_tun, their_tun) = create_two_tuns();
let attacker_ip = Ipv4Addr::new(10, 0, 0, 1);
let legit_ip = Ipv4Addr::new(10, 0, 0, 2);
for _ in 0..HANDSHAKE_RATE_LIMIT {
let init = my_tun
.format_handshake_initiation(true)
.expect("expected handshake init");
their_tun
.rate_limiter
.verify_handshake(attacker_ip.into(), init)
.expect("should be under limit");
MockClock::advance(Duration::from_micros(1));
}
let init = my_tun
.format_handshake_initiation(true)
.expect("expected handshake init");
assert!(
matches!(
their_tun
.rate_limiter
.verify_handshake(attacker_ip.into(), init),
Err(TunnResult::WriteToNetwork(WgKind::CookieReply(_)))
),
"attacker IP should be rate limited"
);
let init = my_tun
.format_handshake_initiation(true)
.expect("expected handshake init");
their_tun
.rate_limiter
.verify_handshake(legit_ip.into(), init)
.expect("legitimate IP should not be rate limited");
}
// TimeCurrent holds its high-water mark during a backward jump and follows the
// clock again once it passes that mark.
#[test]
#[cfg(feature = "mock_instant")]
fn timers_freeze_during_backward_jump() {
const INITIAL_TIME: Duration = Duration::from_secs(100);
const JUMPED_BACK_TIME: Duration = Duration::from_secs(95);
const RESUMED_TIME: Duration = Duration::from_secs(105);
MockClock::set_time(Duration::ZERO);
let (mut my_tun, mut _their_tun) = create_two_tuns_and_handshake();
MockClock::set_time(INITIAL_TIME);
my_tun.update_timers().unwrap();
assert_eq!(my_tun.timers[TimerName::TimeCurrent], INITIAL_TIME);
MockClock::set_time(JUMPED_BACK_TIME);
my_tun.update_timers().unwrap();
assert_eq!(my_tun.timers[TimerName::TimeCurrent], INITIAL_TIME);
MockClock::set_time(RESUMED_TIME);
my_tun.update_timers().unwrap();
assert_eq!(my_tun.timers[TimerName::TimeCurrent], RESUMED_TIME);
}
}