pub mod errors;
pub mod handshake;
pub mod rate_limiter;
mod session;
mod timers;
use zerocopy::IntoBytes;
use crate::noise::errors::WireGuardError;
use crate::noise::handshake::Handshake;
use crate::noise::rate_limiter::RateLimiter;
use crate::noise::timers::{TimerName, Timers};
use crate::packet::{Packet, WgCookieReply, WgData, WgHandshakeInit, WgHandshakeResp, WgKind};
use crate::x25519;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;
const PEER_HANDSHAKE_RATE_LIMIT: u64 = 10;
const MAX_QUEUE_DEPTH: usize = 256;
const N_SESSIONS: usize = 8;
#[derive(Debug)]
pub enum TunnResult {
Done,
Err(WireGuardError),
WriteToNetwork(WgKind),
WriteToTunnel(Packet),
}
impl From<WireGuardError> for TunnResult {
fn from(err: WireGuardError) -> TunnResult {
TunnResult::Err(err)
}
}
pub struct Tunn {
handshake: handshake::Handshake,
sessions: [Option<session::Session>; N_SESSIONS],
current: usize,
packet_queue: VecDeque<Packet>,
timers: timers::Timers,
tx_bytes: usize,
rx_bytes: usize,
rate_limiter: Arc<RateLimiter>,
}
impl Tunn {
pub fn is_expired(&self) -> bool {
self.handshake.is_expired()
}
pub fn new(
static_private: x25519::StaticSecret,
peer_static_public: x25519::PublicKey,
preshared_key: Option<[u8; 32]>,
persistent_keepalive: Option<u16>,
index: u32,
rate_limiter: Option<Arc<RateLimiter>>,
) -> Self {
let static_public = x25519::PublicKey::from(&static_private);
Tunn {
handshake: Handshake::new(
static_private,
static_public,
peer_static_public,
index << 8,
preshared_key,
),
sessions: Default::default(),
current: Default::default(),
tx_bytes: Default::default(),
rx_bytes: Default::default(),
packet_queue: VecDeque::new(),
timers: Timers::new(persistent_keepalive, rate_limiter.is_none()),
rate_limiter: rate_limiter.unwrap_or_else(|| {
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
}),
}
}
pub fn set_static_private(
&mut self,
static_private: x25519::StaticSecret,
static_public: x25519::PublicKey,
rate_limiter: Option<Arc<RateLimiter>>,
) {
self.timers.should_reset_rr = rate_limiter.is_none();
self.rate_limiter = rate_limiter.unwrap_or_else(|| {
Arc::new(RateLimiter::new(&static_public, PEER_HANDSHAKE_RATE_LIMIT))
});
self.handshake
.set_static_private(static_private, static_public);
for s in &mut self.sessions {
*s = None;
}
}
pub fn handle_outgoing_packet(&mut self, packet: Packet) -> Option<WgKind> {
match self.encapsulate_with_session(packet) {
Ok(encapsulated_packet) => Some(encapsulated_packet.into()),
Err(packet) => {
self.queue_packet(packet);
self.format_handshake_initiation(false).map(Into::into)
}
}
}
pub fn encapsulate_with_session(&mut self, packet: Packet) -> Result<Packet<WgData>, Packet> {
let current = self.current;
if let Some(ref session) = self.sessions[current % N_SESSIONS] {
let packet = session.format_packet_data(packet);
self.timer_tick(TimerName::TimeLastPacketSent);
if !packet.as_bytes().is_empty() {
self.timer_tick(TimerName::TimeLastDataPacketSent);
}
self.tx_bytes += packet.as_bytes().len();
Ok(packet)
} else {
Err(packet)
}
}
pub fn handle_incoming_packet(&mut self, packet: WgKind) -> TunnResult {
match packet {
WgKind::HandshakeInit(p) => self.handle_handshake_init(p),
WgKind::HandshakeResp(p) => self.handle_handshake_response(p),
WgKind::CookieReply(p) => self.handle_cookie_reply(&p),
WgKind::Data(p) => self.handle_data(p),
}
.unwrap_or_else(TunnResult::from)
}
fn handle_handshake_init(
&mut self,
p: Packet<WgHandshakeInit>,
) -> Result<TunnResult, WireGuardError> {
log::debug!("Received handshake_initiation: {}", p.sender_idx);
let (packet, session) = self.handshake.receive_handshake_initialization(p)?;
let index = session.local_index();
self.sessions[index % N_SESSIONS] = Some(session);
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeLastPacketSent);
self.timer_tick_session_established(false, index);
log::debug!("Sending handshake_response: {index}");
Ok(TunnResult::WriteToNetwork(packet.into()))
}
fn handle_handshake_response(
&mut self,
p: Packet<WgHandshakeResp>,
) -> Result<TunnResult, WireGuardError> {
log::debug!(
"Received handshake_response: {} {}",
p.receiver_idx,
p.sender_idx,
);
let session = self.handshake.receive_handshake_response(&p)?;
let mut p = p.into_bytes();
p.truncate(0);
let keepalive_packet = session.format_packet_data(p);
let l_idx = session.local_index();
let index = l_idx % N_SESSIONS;
self.sessions[index] = Some(session);
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick_session_established(true, index); self.set_current_session(l_idx);
log::debug!("Sending keepalive");
Ok(TunnResult::WriteToNetwork(keepalive_packet.into())) }
fn handle_cookie_reply(&mut self, p: &WgCookieReply) -> Result<TunnResult, WireGuardError> {
log::debug!("Received cookie_reply: {}", p.receiver_idx);
self.handshake.receive_cookie_reply(p)?;
self.timer_tick(TimerName::TimeLastPacketReceived);
self.timer_tick(TimerName::TimeCookieReceived);
log::debug!("Did set cookie");
Ok(TunnResult::Done)
}
fn set_current_session(&mut self, new_idx: usize) {
let cur_idx = self.current;
if cur_idx == new_idx {
return;
}
if self.sessions[cur_idx % N_SESSIONS].is_none()
|| self.timers.session_timers[new_idx % N_SESSIONS]
>= self.timers.session_timers[cur_idx % N_SESSIONS]
{
self.current = new_idx;
log::debug!("New session: {new_idx}");
}
}
fn handle_data(&mut self, packet: Packet<WgData>) -> Result<TunnResult, WireGuardError> {
let decapsulated_packet = self.decapsulate_with_session(packet)?;
self.timer_tick(TimerName::TimeLastDataPacketReceived);
self.rx_bytes += decapsulated_packet.as_bytes().len();
Ok(TunnResult::WriteToTunnel(decapsulated_packet))
}
pub fn decapsulate_with_session(
&mut self,
packet: Packet<WgData>,
) -> Result<Packet, WireGuardError> {
let r_idx = packet.header.receiver_idx.get() as usize;
let idx = r_idx % N_SESSIONS;
let decapsulated_packet = {
let session = self.sessions[idx].as_ref();
let session = session.ok_or_else(|| {
log::trace!("No current session available: {r_idx}");
WireGuardError::NoCurrentSession
})?;
session.receive_packet_data(packet)?
};
self.set_current_session(r_idx);
self.timer_tick(TimerName::TimeLastPacketReceived);
Ok(decapsulated_packet)
}
pub fn format_handshake_initiation(
&mut self,
force_resend: bool,
) -> Option<Packet<WgHandshakeInit>> {
if self.handshake.is_in_progress() && !force_resend {
return None;
}
if self.handshake.is_expired() {
self.timers.clear();
}
let starting_new_handshake = !self.handshake.is_in_progress();
let packet = self.handshake.format_handshake_initiation();
log::debug!("Sending handshake_initiation");
if starting_new_handshake {
self.timer_tick(TimerName::TimeLastHandshakeStarted);
}
self.timer_tick(TimerName::TimeLastPacketSent);
Some(packet)
}
pub fn next_queued_packet(&mut self) -> Option<WgKind> {
self.dequeue_packet()
.and_then(|packet| self.handle_outgoing_packet(packet))
}
fn queue_packet(&mut self, packet: Packet) {
if self.packet_queue.len() < MAX_QUEUE_DEPTH {
self.packet_queue.push_back(packet);
}
}
fn dequeue_packet(&mut self) -> Option<Packet> {
self.packet_queue.pop_front()
}
fn estimate_loss(&self) -> f32 {
let session_idx = self.current;
let mut weight = 9.0;
let mut cur_avg = 0.0;
let mut total_weight = 0.0;
for i in 0..N_SESSIONS {
if let Some(ref session) = self.sessions[(session_idx.wrapping_sub(i)) % N_SESSIONS] {
let (expected, received) = session.current_packet_cnt();
let loss = if expected == 0 {
0.0
} else {
1.0 - received as f32 / expected as f32
};
cur_avg += loss * weight;
total_weight += weight;
weight /= 3.0;
}
}
if total_weight == 0.0 {
0.0
} else {
cur_avg / total_weight
}
}
pub fn stats(&self) -> (Option<Duration>, usize, usize, f32, Option<u32>) {
let time = self.time_since_last_handshake();
let tx_bytes = self.tx_bytes;
let rx_bytes = self.rx_bytes;
let loss = self.estimate_loss();
let rtt = self.handshake.last_rtt;
(time, tx_bytes, rx_bytes, loss, rtt)
}
}
#[cfg(test)]
mod tests {
#[cfg(feature = "mock_instant")]
use crate::noise::timers::{REKEY_AFTER_TIME, REKEY_TIMEOUT};
use crate::packet::Ipv4;
use super::*;
use bytes::BytesMut;
use rand_core::{OsRng, RngCore};
fn create_two_tuns() -> (Tunn, Tunn) {
let my_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let my_public_key = x25519_dalek::PublicKey::from(&my_secret_key);
let my_idx = OsRng.next_u32();
let their_secret_key = x25519_dalek::StaticSecret::random_from_rng(OsRng);
let their_public_key = x25519_dalek::PublicKey::from(&their_secret_key);
let their_idx = OsRng.next_u32();
let my_tun = Tunn::new(my_secret_key, their_public_key, None, None, my_idx, None);
let their_tun = Tunn::new(their_secret_key, my_public_key, None, None, their_idx, None);
(my_tun, their_tun)
}
fn create_handshake_init(tun: &mut Tunn) -> Packet<WgHandshakeInit> {
tun.format_handshake_initiation(false)
.expect("expected handshake init")
}
fn create_handshake_response(
tun: &mut Tunn,
handshake_init: Packet<WgHandshakeInit>,
) -> Packet<WgHandshakeResp> {
let handshake_resp = tun.handle_incoming_packet(WgKind::HandshakeInit(handshake_init));
assert!(matches!(handshake_resp, TunnResult::WriteToNetwork(_)));
let TunnResult::WriteToNetwork(handshake_resp) = handshake_resp else {
unreachable!("expected WriteToNetwork");
};
let WgKind::HandshakeResp(handshake_resp) = handshake_resp else {
unreachable!("expected WgHandshakeResp, got {handshake_resp:?}");
};
handshake_resp
}
fn parse_handshake_resp(
tun: &mut Tunn,
handshake_resp: Packet<WgHandshakeResp>,
) -> Packet<WgData> {
let keepalive = tun.handle_incoming_packet(WgKind::HandshakeResp(handshake_resp));
assert!(matches!(keepalive, TunnResult::WriteToNetwork(_)));
let TunnResult::WriteToNetwork(keepalive) = keepalive else {
unreachable!("expected WriteToNetwork")
};
let WgKind::Data(keepalive) = keepalive else {
unreachable!("expected WgData, got {keepalive:?}");
};
keepalive
}
fn parse_keepalive(tun: &mut Tunn, keepalive: Packet<WgData>) {
let result = tun.handle_incoming_packet(WgKind::Data(keepalive));
assert!(matches!(result, TunnResult::WriteToTunnel(p) if p.is_empty()));
}
fn create_two_tuns_and_handshake() -> (Tunn, Tunn) {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init);
let keepalive = parse_handshake_resp(&mut my_tun, resp);
parse_keepalive(&mut their_tun, keepalive);
(my_tun, their_tun)
}
fn create_ipv4_udp_packet() -> Packet<Ipv4> {
let header =
etherparse::PacketBuilder::ipv4([192, 168, 1, 2], [192, 168, 1, 3], 5).udp(5678, 23);
let payload = [0, 1, 2, 3];
let mut packet = Vec::<u8>::with_capacity(header.size(payload.len()));
header.write(&mut packet, &payload).unwrap();
let packet = Packet::from_bytes(BytesMut::from(&packet[..]));
packet.try_into_ipvx().unwrap().unwrap_left()
}
#[cfg(feature = "mock_instant")]
fn update_timer_results_in_handshake(tun: &mut Tunn) {
let packet = tun
.update_timers()
.expect("update_timers should succeed")
.unwrap();
assert!(matches!(packet, WgKind::HandshakeInit(..)));
}
#[test]
fn create_two_tunnels_linked_to_eachother() {
let (_my_tun, _their_tun) = create_two_tuns();
}
#[test]
fn handshake_init() {
let (mut my_tun, _their_tun) = create_two_tuns();
let _init = create_handshake_init(&mut my_tun);
}
#[test]
fn verify_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init.clone());
their_tun
.rate_limiter
.verify_handshake(None, init)
.expect("Handshake init to be valid");
my_tun
.rate_limiter
.verify_handshake(None, resp)
.expect("Handshake response to be valid");
}
#[test]
fn reject_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let mut init = create_handshake_init(&mut my_tun);
let mut resp = create_handshake_response(&mut their_tun, init.clone());
std::mem::swap(&mut init.mac1, &mut resp.mac1);
their_tun
.rate_limiter
.verify_handshake(None, init.clone())
.map(|packet| packet.mac1)
.expect_err("Handshake init to be invalid");
my_tun
.rate_limiter
.verify_handshake(None, resp)
.map(|packet| packet.mac1)
.expect_err("Handshake response to be invalid");
}
#[test]
fn handshake_init_and_response() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let _resp = create_handshake_response(&mut their_tun, init);
}
#[test]
fn full_handshake() {
let (mut my_tun, mut their_tun) = create_two_tuns();
let init = create_handshake_init(&mut my_tun);
let resp = create_handshake_response(&mut their_tun, init);
let _keepalive = parse_handshake_resp(&mut my_tun, resp);
}
#[test]
fn full_handshake_plus_timers() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
assert!(matches!(my_tun.update_timers(), Ok(None)));
assert!(matches!(their_tun.update_timers(), Ok(None)));
}
#[test]
#[cfg(feature = "mock_instant")]
fn new_handshake_after_two_mins() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
mock_instant::MockClock::advance(Duration::from_secs(1));
assert!(matches!(their_tun.update_timers(), Ok(None)));
assert!(matches!(my_tun.update_timers(), Ok(None)));
let sent_packet_buf = create_ipv4_udp_packet();
let _data = my_tun
.handle_outgoing_packet(sent_packet_buf.into_bytes())
.expect("expected encapsulated packet");
mock_instant::MockClock::advance(REKEY_AFTER_TIME);
assert!(matches!(their_tun.update_timers(), Ok(None)));
update_timer_results_in_handshake(&mut my_tun);
}
#[test]
#[cfg(feature = "mock_instant")]
fn handshake_no_resp_rekey_timeout() {
let (mut my_tun, _their_tun) = create_two_tuns();
let _init = create_handshake_init(&mut my_tun);
mock_instant::MockClock::advance(REKEY_TIMEOUT);
update_timer_results_in_handshake(&mut my_tun)
}
#[test]
fn one_ip_packet() {
let (mut my_tun, mut their_tun) = create_two_tuns_and_handshake();
let sent_packet_buf = create_ipv4_udp_packet();
let data = my_tun
.handle_outgoing_packet(sent_packet_buf.clone().into_bytes())
.unwrap();
assert!(matches!(data, WgKind::Data(..)));
let data = their_tun.handle_incoming_packet(data);
let recv_packet_buf = if let TunnResult::WriteToTunnel(recv) = data {
recv
} else {
unreachable!("expected WritetoTunnelV4");
};
assert_eq!(sent_packet_buf.as_bytes(), recv_packet_buf.as_bytes());
}
}