use std::{
collections::{HashMap, VecDeque},
net::SocketAddr,
sync::Arc,
time::Instant,
};
use ana_gotatun::{
noise::{Tunn, TunnResult, handshake::parse_handshake_anon, rate_limiter::RateLimiter},
packet::{Packet, WgKind},
x25519,
};
/// Server side of a snap tunnel.
///
/// Accepts WireGuard handshake initiations from any peer that passes the authorization
/// policy `T`, and keeps exactly one [`Tunn`] per remote socket address.
pub struct SnapTunServer<T> {
    /// The server's static private key; every tunnel is created with a clone of it.
    static_private: x25519::StaticSecret,
    /// Public counterpart of `static_private`, needed to parse handshakes from unknown peers.
    static_public: x25519::PublicKey,
    /// Live tunnels keyed by the remote address, each stored with the peer's static public
    /// key so authorization can be re-checked on later packets.
    active_tunnels: HashMap<SocketAddr, (x25519::PublicKey, Tunn)>,
    /// Rate limiter applied to every incoming packet and shared with each created tunnel.
    rate_limiter: Arc<RateLimiter>,
    /// Authorization policy consulted with the peer's static public key.
    authz: Arc<T>,
}
impl<T: SnapTunAuthorization> SnapTunServer<T> {
pub fn new(
static_private: x25519::StaticSecret,
rate_limiter: Arc<RateLimiter>,
authz: Arc<T>,
) -> Self {
let static_public = x25519::PublicKey::from(&static_private);
Self {
static_private,
static_public,
active_tunnels: Default::default(),
rate_limiter,
authz,
}
}
#[tracing::instrument(skip_all, fields(remote = %from))]
pub fn handle_incoming_packet(
&mut self,
packet: Packet,
from: SocketAddr,
send_to_network: &mut VecDeque<WgKind>,
) -> TunnResult {
let now = Instant::now();
let parsed_packet = match self.rate_limiter.verify_packet(from.ip(), packet) {
Ok(p) => p,
Err(TunnResult::WriteToNetwork(c)) => {
tracing::debug!(remote = ?from, "rate limiter issued cookie reply");
send_to_network.push_back(c);
return TunnResult::Done;
}
Err(e) => {
tracing::debug!(remote = ?from, err = ?e, "rate limiter rejected packet");
return e;
}
};
use std::collections::hash_map::Entry;
use ana_gotatun::noise::errors::WireGuardError;
match (self.active_tunnels.entry(from), parsed_packet) {
(Entry::Occupied(mut occupied_entry), p) => {
let (peer_static, tunn) = occupied_entry.get_mut();
if !self.authz.is_authorized(now, peer_static.as_bytes()) {
tracing::debug!(remote = ?from, "rejected packet from unauthorized peer");
return TunnResult::Err(WireGuardError::UnexpectedPacket);
}
Self::handle_incoming_and_drain_queue(send_to_network, p, tunn)
}
(e, WgKind::HandshakeInit(wg_init)) => {
let peer = match parse_handshake_anon(
&self.static_private,
&self.static_public,
&wg_init,
) {
Ok(v) => v,
Err(e) => {
tracing::debug!(remote = ?from, err = ?e, "failed to parse handshake init");
return TunnResult::from(e);
}
};
if !self.authz.is_authorized(now, &peer.peer_static_public) {
tracing::debug!(remote = ?from, "rejected handshake from unauthorized peer");
return TunnResult::Err(WireGuardError::UnexpectedPacket);
}
tracing::debug!(remote = ?from, "accepted new handshake, inserting tunnel");
let peer_static = x25519::PublicKey::from(peer.peer_static_public);
let mut tunn = Tunn::new(
self.static_private.clone(),
peer_static,
None,
None,
0,
self.rate_limiter.clone(),
from,
);
let res = Self::handle_incoming_and_drain_queue(
send_to_network,
WgKind::HandshakeInit(wg_init),
&mut tunn,
);
e.insert_entry((peer_static, tunn));
res
}
(_, _p) => {
tracing::debug!(remote = ?from, "received unexpected packet kind for new entry");
TunnResult::Err(WireGuardError::InvalidPacket)
}
}
}
#[tracing::instrument(skip_all, fields(remote = %to))]
pub fn handle_outgoing_packet(&mut self, packet: Packet, to: SocketAddr) -> Option<WgKind> {
let Some((_, tunn)) = self.active_tunnels.get_mut(&to) else {
tracing::error!(to=?to, "No tunnel for outgoing packet found.");
return None;
};
tunn.handle_outgoing_packet(packet.into_bytes())
}
pub fn update_timers(&mut self) -> Vec<(SocketAddr, WgKind)> {
let mut res = vec![];
self.active_tunnels.retain(|k, (_, tunn)| {
match tunn.update_timers() {
Ok(Some(wg)) => res.push((*k, wg)),
Ok(None) => {},
Err(e) => tracing::error!(err=?e, remote_sockaddr=?k, "error when updating timers on tunnel"),
}
!tunn.is_expired()
});
res
}
fn handle_incoming_and_drain_queue(
q: &mut VecDeque<WgKind>,
p: WgKind,
tunn: &mut Tunn,
) -> TunnResult {
let r = match tunn.handle_incoming_packet(p) {
TunnResult::WriteToNetwork(p) => {
q.push_back(p);
TunnResult::Done
}
TunnResult::WriteToTunnel(p) if p.is_empty() => TunnResult::Done,
r => r,
};
for p in tunn.get_queued_packets() {
q.push_back(p);
}
r
}
}
/// Authorization policy consulted by [`SnapTunServer`] for every incoming packet.
///
/// Implementations must be shareable across threads. The `Send + Sync` supertraits are
/// expressed here as a `where Self:` clause.
pub trait SnapTunAuthorization
where
    Self: Send + Sync,
{
    /// Decides whether the peer identified by `identity` may use the tunnel at time `now`.
    ///
    /// `identity` holds the raw bytes of the peer's 32-byte static public key.
    fn is_authorized(&self, now: Instant, identity: &[u8; 32]) -> bool;
}
#[cfg(test)]
mod tests {
// Tests for `SnapTunServer`: full handshake and data exchange with several clients.
use std::{collections::VecDeque, net::SocketAddr, sync::Arc};
use ana_gotatun::{
noise::{Tunn, TunnResult, rate_limiter::RateLimiter},
packet::{IpNextProtocol, Packet, WgKind},
x25519,
};
use zerocopy::IntoBytes;
use crate::{
scion_packet::{Scion, ScionHeader},
server::{SnapTunAuthorization, SnapTunServer},
};
/// Result type for tests that may use `?`.
type ResultT = Result<(), Box<dyn std::error::Error>>;
/// Authorization policy that accepts every identity at any time.
struct TrivialAuthz;
impl SnapTunAuthorization for TrivialAuthz {
fn is_authorized(&self, _now: std::time::Instant, _ident: &[u8; 32]) -> bool {
true
}
}
/// Two clients handshake with one server from distinct addresses.
///
/// Steps:
/// 1. Each client's first data packet triggers a handshake initiation and is queued on the
///    client until the handshake completes.
/// 2. The server decrypts each client's data packet.
/// 3. The server re-encrypts client 0's packet towards client 1, and client 1 decrypts it.
#[test]
fn connect_with_multiple_clients() -> ResultT {
// Fixed keys and addresses keep the test deterministic.
let sockaddr_client0: SocketAddr = "192.168.1.1:1234".parse().unwrap();
let static_client0 = x25519::StaticSecret::from([0u8; 32]);
let sockaddr_client1: SocketAddr = "192.168.1.2:4321".parse().unwrap();
let static_client1 = x25519::StaticSecret::from([1u8; 32]);
let sockaddr_server: SocketAddr = "10.0.0.1:5001".parse().unwrap();
let static_server = x25519::StaticSecret::from([2u8; 32]);
let static_server_public = x25519::PublicKey::from(&static_server);
// A single rate limiter is shared by the server and both client tunnels.
let rate_limiter = Arc::new(RateLimiter::new(&static_server_public, 100));
let mut snaptun_server =
SnapTunServer::new(static_server, rate_limiter.clone(), Arc::new(TrivialAuthz));
// Collects everything the server wants to put on the wire.
let mut send_to_network = VecDeque::<WgKind>::new();
let test_payload0 = [b'T', b'E', b'S', b'T', b'0'];
let test_payload1 = [b'T', b'E', b'S', b'T', b'1'];
// Two SCION packets that share a header and differ only in their payload.
let test_packet0 = Scion {
header: ScionHeader::new(
0, 0xAA, 0xABCDE, test_payload0.len() as _, IpNextProtocol::Udp,
7, 0x0123_4567_89AB_CDEF,
0xFEDC_BA98_7654_3210,
),
payload: test_payload0,
};
let test_packet1 = Scion {
header: test_packet0.header,
payload: test_payload1,
};
let test_packet0 = Packet::copy_from(test_packet0.as_bytes());
let test_packet1 = Packet::copy_from(test_packet1.as_bytes());
// Client-side tunnels, both targeting the server's static key and address.
let mut tunn_client0 = Tunn::new(
static_client0,
static_server_public,
None,
None,
0,
rate_limiter.clone(),
sockaddr_server,
);
let mut tunn_client1 = Tunn::new(
static_client1,
static_server_public,
None,
None,
0,
rate_limiter,
sockaddr_server,
);
// Client 0: with no session yet, sending data yields a handshake initiation.
let Some(WgKind::HandshakeInit(hs_init)) =
tunn_client0.handle_outgoing_packet(Packet::copy_from(&test_packet0))
else {
panic!("expected handshake init")
};
snaptun_server.handle_incoming_packet(
Packet::copy_from(hs_init.as_bytes()),
sockaddr_client0,
&mut send_to_network,
);
// Deliver the server's reply (the oldest queued packet) back to client 0.
dispatch_one(&mut tunn_client0, &mut send_to_network);
// After the reply, client 0's tunnel reports its own address as the initiator remote
// sockaddr. Presumably the server echoes it back; TODO: confirm in gotatun.
assert_eq!(
tunn_client0.get_initiator_remote_sockaddr(),
Some(sockaddr_client0)
);
// Client 1: the same handshake flow from a different address.
let Some(WgKind::HandshakeInit(hs_init)) =
tunn_client1.handle_outgoing_packet(Packet::copy_from(&test_packet1))
else {
panic!("expected handshake init")
};
snaptun_server.handle_incoming_packet(
Packet::copy_from(hs_init.as_bytes()),
sockaddr_client1,
&mut send_to_network,
);
dispatch_one(&mut tunn_client1, &mut send_to_network);
assert_eq!(
tunn_client1.get_initiator_remote_sockaddr(),
Some(sockaddr_client1)
);
// The data packet queued during the handshake is now available; the server must decrypt
// it back to the original bytes.
let Some(WgKind::Data(p)) = tunn_client0.get_queued_packets().next() else {
panic!("expected packet to be queued");
};
let TunnResult::WriteToTunnel(p) = snaptun_server.handle_incoming_packet(
Packet::copy_from(p.as_bytes()),
sockaddr_client0,
&mut send_to_network,
) else {
panic!("Expected packet to be processed")
};
assert_eq!(p.as_bytes(), test_packet0.as_bytes());
let Some(WgKind::Data(p1)) = tunn_client1.get_queued_packets().next() else {
panic!("expected packet to be queued");
};
let TunnResult::WriteToTunnel(p1) = snaptun_server.handle_incoming_packet(
Packet::copy_from(p1.as_bytes()),
sockaddr_client1,
&mut send_to_network,
) else {
panic!("expected packet to be received on server side");
};
assert_eq!(p1.as_bytes(), test_packet1.as_bytes());
// Route client 0's decrypted packet to client 1 through client 1's tunnel; client 1 must
// decrypt it to client 0's original bytes.
let res = snaptun_server.handle_outgoing_packet(p, sockaddr_client1);
let Some(p @ WgKind::Data(_)) = res else {
panic!("expected packet to be sent back to client")
};
let TunnResult::WriteToTunnel(p) = tunn_client1.handle_incoming_packet(p) else {
panic!("expected packet to be sent back to client")
};
assert_eq!(p.as_bytes(), test_packet0.as_bytes());
Ok(())
}
/// Pops the oldest packet from `packets` and feeds it into `tunn`.
/// Returns `TunnResult::Done` when the queue is empty.
fn dispatch_one(tunn: &mut Tunn, packets: &mut VecDeque<WgKind>) -> TunnResult {
if let Some(p) = packets.pop_front() {
return tunn.handle_incoming_packet(p);
}
TunnResult::Done
}
}