1use std::{
22 collections::{HashMap, VecDeque},
23 net::SocketAddr,
24 sync::Arc,
25 time::Instant,
26};
27
28use ana_gotatun::{
29 noise::{Tunn, TunnResult, handshake::parse_handshake_anon, rate_limiter::RateLimiter},
30 packet::{Packet, WgKind},
31 x25519,
32};
33
/// Server side of a snap tunnel.
///
/// Holds one [`Tunn`] per remote socket address, together with the static
/// public key of the peer that established it. Which peers may establish and
/// keep using a tunnel is decided by the authorization policy `T`.
pub struct SnapTunServer<T> {
    /// The server's long-term private key; used to parse incoming handshake
    /// initiations and to create per-peer tunnels.
    static_private: x25519::StaticSecret,
    /// Public key derived from `static_private` when the server is created.
    static_public: x25519::PublicKey,
    /// Tunnels keyed by the remote address they were created for, each paired
    /// with the static public key of the peer on the other end.
    active_tunnels: HashMap<SocketAddr, (x25519::PublicKey, Tunn)>,
    /// Rate limiter consulted for every incoming packet and shared with every
    /// tunnel the server creates.
    rate_limiter: Arc<RateLimiter>,
    /// Authorization policy, queried with a peer's static public key bytes.
    authz: Arc<T>,
}
91
92impl<T: SnapTunAuthorization> SnapTunServer<T> {
93 pub fn new(
95 static_private: x25519::StaticSecret,
96 rate_limiter: Arc<RateLimiter>,
97 authz: Arc<T>,
98 ) -> Self {
99 let static_public = x25519::PublicKey::from(&static_private);
100 Self {
101 static_private,
102 static_public,
103 active_tunnels: Default::default(),
104 rate_limiter,
105 authz,
106 }
107 }
108
109 #[tracing::instrument(skip_all, fields(remote = %from))]
120 pub fn handle_incoming_packet(
121 &mut self,
122 packet: Packet,
123 from: SocketAddr,
124 send_to_network: &mut VecDeque<WgKind>,
125 ) -> TunnResult {
126 let now = Instant::now();
127
128 let parsed_packet = match self.rate_limiter.verify_packet(from.ip(), packet) {
129 Ok(p) => p,
130 Err(TunnResult::WriteToNetwork(c)) => {
131 tracing::debug!(remote = ?from, "rate limiter issued cookie reply");
132 send_to_network.push_back(c);
133 return TunnResult::Done;
134 }
135 Err(e) => {
136 tracing::debug!(remote = ?from, err = ?e, "rate limiter rejected packet");
137 return e;
138 }
139 };
140
141 use std::collections::hash_map::Entry;
142
143 use ana_gotatun::noise::errors::WireGuardError;
144 match (self.active_tunnels.entry(from), parsed_packet) {
145 (Entry::Occupied(mut occupied_entry), p) => {
146 let (peer_static, tunn) = occupied_entry.get_mut();
147 if !self.authz.is_authorized(now, peer_static.as_bytes()) {
154 tracing::debug!(remote = ?from, "rejected packet from unauthorized peer");
155 return TunnResult::Err(WireGuardError::UnexpectedPacket);
156 }
157 Self::handle_incoming_and_drain_queue(send_to_network, p, tunn)
158 }
159 (e, WgKind::HandshakeInit(wg_init)) => {
160 let peer = match parse_handshake_anon(
161 &self.static_private,
162 &self.static_public,
163 &wg_init,
164 ) {
165 Ok(v) => v,
166 Err(e) => {
167 tracing::debug!(remote = ?from, err = ?e, "failed to parse handshake init");
168 return TunnResult::from(e);
169 }
170 };
171
172 if !self.authz.is_authorized(now, &peer.peer_static_public) {
178 tracing::debug!(remote = ?from, "rejected handshake from unauthorized peer");
179 return TunnResult::Err(WireGuardError::UnexpectedPacket);
180 }
181 tracing::debug!(remote = ?from, "accepted new handshake, inserting tunnel");
182 let peer_static = x25519::PublicKey::from(peer.peer_static_public);
183 let mut tunn = Tunn::new(
184 self.static_private.clone(),
185 peer_static,
186 None,
187 None,
188 0,
189 self.rate_limiter.clone(),
190 from,
191 );
192 let res = Self::handle_incoming_and_drain_queue(
193 send_to_network,
194 WgKind::HandshakeInit(wg_init),
195 &mut tunn,
196 );
197 e.insert_entry((peer_static, tunn));
198 res
199 }
200 (_, _p) => {
201 tracing::debug!(remote = ?from, "received unexpected packet kind for new entry");
202 TunnResult::Err(WireGuardError::InvalidPacket)
203 }
204 }
205 }
206
207 #[tracing::instrument(skip_all, fields(remote = %to))]
210 pub fn handle_outgoing_packet(&mut self, packet: Packet, to: SocketAddr) -> Option<WgKind> {
211 let Some((_, tunn)) = self.active_tunnels.get_mut(&to) else {
212 tracing::error!(to=?to, "No tunnel for outgoing packet found.");
213 return None;
214 };
215 tunn.handle_outgoing_packet(packet.into_bytes())
216 }
217
218 pub fn update_timers(&mut self) -> Vec<(SocketAddr, WgKind)> {
224 let mut res = vec![];
225 self.active_tunnels.retain(|k, (_, tunn)| {
226 match tunn.update_timers() {
227 Ok(Some(wg)) => res.push((*k, wg)),
228 Ok(None) => {},
229 Err(e) => tracing::error!(err=?e, remote_sockaddr=?k, "error when updating timers on tunnel"),
230 }
231
232 !tunn.is_expired()
233 });
234 res
235 }
236
237 fn handle_incoming_and_drain_queue(
238 q: &mut VecDeque<WgKind>,
239 p: WgKind,
240 tunn: &mut Tunn,
241 ) -> TunnResult {
242 let r = match tunn.handle_incoming_packet(p) {
243 TunnResult::WriteToNetwork(p) => {
244 q.push_back(p);
245 TunnResult::Done
246 }
247 TunnResult::WriteToTunnel(p) if p.is_empty() => TunnResult::Done,
249 r => r,
250 };
251 for p in tunn.get_queued_packets() {
252 q.push_back(p);
253 }
254 r
255 }
256}
257
/// Authorization policy for peers of a snap tunnel server.
///
/// A peer is identified by the 32 bytes of its static public key. The policy
/// is queried with the current time, so decisions may change over time.
/// Implementors must be thread-safe: `Send + Sync` are supertraits.
pub trait SnapTunAuthorization
where
    Self: Send + Sync,
{
    /// Returns `true` if the peer whose static public key bytes are
    /// `identity` is authorized at time `now`.
    fn is_authorized(&self, now: Instant, identity: &[u8; 32]) -> bool;
}
263
// Tests for `SnapTunServer`: real client-side `Tunn`s are driven against a
// single server instance to check that tunnels are set up and routed per
// remote socket address.
#[cfg(test)]
mod tests {
    use std::{collections::VecDeque, net::SocketAddr, sync::Arc};

    use ana_gotatun::{
        noise::{Tunn, TunnResult, rate_limiter::RateLimiter},
        packet::{IpNextProtocol, Packet, WgKind},
        x25519,
    };
    use zerocopy::IntoBytes;

    use crate::{
        scion_packet::{Scion, ScionHeader},
        server::{SnapTunAuthorization, SnapTunServer},
    };

    type ResultT = Result<(), Box<dyn std::error::Error>>;

    // Policy that authorizes every peer, so the test exercises only the
    // handshake and routing logic.
    struct TrivialAuthz;

    impl SnapTunAuthorization for TrivialAuthz {
        fn is_authorized(&self, _now: std::time::Instant, _ident: &[u8; 32]) -> bool {
            true
        }
    }

    // Test outline:
    // - Two clients with distinct static keys and source addresses handshake
    //   with the same server.
    // - Each client then delivers one queued data packet to the server.
    // - The packet received from client 0 is sent out to client 1's address.
    //   Client 1 must decrypt it, which checks that outgoing traffic is routed
    //   by destination address.
    #[test]
    fn connect_with_multiple_clients() -> ResultT {
        let sockaddr_client0: SocketAddr = "192.168.1.1:1234".parse().unwrap();
        let static_client0 = x25519::StaticSecret::from([0u8; 32]);
        let sockaddr_client1: SocketAddr = "192.168.1.2:4321".parse().unwrap();
        let static_client1 = x25519::StaticSecret::from([1u8; 32]);
        let sockaddr_server: SocketAddr = "10.0.0.1:5001".parse().unwrap();
        let static_server = x25519::StaticSecret::from([2u8; 32]);
        let static_server_public = x25519::PublicKey::from(&static_server);

        // One limiter instance is shared by the server and both client tunnels.
        let rate_limiter = Arc::new(RateLimiter::new(&static_server_public, 100));
        let mut snaptun_server =
            SnapTunServer::new(static_server, rate_limiter.clone(), Arc::new(TrivialAuthz));

        // Everything the server wants to put on the wire ends up here.
        let mut send_to_network = VecDeque::<WgKind>::new();

        let test_payload0 = [b'T', b'E', b'S', b'T', b'0'];
        let test_payload1 = [b'T', b'E', b'S', b'T', b'1'];
        // NOTE(review): the meaning of the individual `ScionHeader::new` arguments
        // is defined in `scion_packet`; they are arbitrary test values here.
        let test_packet0 = Scion {
            header: ScionHeader::new(
                0, 0xAA, 0xABCDE, test_payload0.len() as _, IpNextProtocol::Udp,
                7, 0x0123_4567_89AB_CDEF,
                0xFEDC_BA98_7654_3210,
            ),
            payload: test_payload0,
        };
        // Same header, different payload, so the two clients' packets are distinguishable.
        let test_packet1 = Scion {
            header: test_packet0.header,
            payload: test_payload1,
        };
        let test_packet0 = Packet::copy_from(test_packet0.as_bytes());
        let test_packet1 = Packet::copy_from(test_packet1.as_bytes());

        let mut tunn_client0 = Tunn::new(
            static_client0,
            static_server_public,
            None,
            None,
            0,
            rate_limiter.clone(),
            sockaddr_server,
        );

        let mut tunn_client1 = Tunn::new(
            static_client1,
            static_server_public,
            None,
            None,
            0,
            rate_limiter,
            sockaddr_server,
        );

        // Without a session, client 0's first outgoing packet yields a handshake
        // initiation. The data packet stays queued and is fetched further below.
        let Some(WgKind::HandshakeInit(hs_init)) =
            tunn_client0.handle_outgoing_packet(Packet::copy_from(&test_packet0))
        else {
            panic!("expected handshake init")
        };

        snaptun_server.handle_incoming_packet(
            Packet::copy_from(hs_init.as_bytes()),
            sockaddr_client0,
            &mut send_to_network,
        );

        // Deliver the server's reply (presumably the handshake response) to client 0.
        dispatch_one(&mut tunn_client0, &mut send_to_network);
        // NOTE(review): asserts the client tunnel reports client 0's address as the
        // initiator's remote sockaddr; see `Tunn::get_initiator_remote_sockaddr`.
        assert_eq!(
            tunn_client0.get_initiator_remote_sockaddr(),
            Some(sockaddr_client0)
        );

        // Same handshake sequence for client 1 from its own address.
        let Some(WgKind::HandshakeInit(hs_init)) =
            tunn_client1.handle_outgoing_packet(Packet::copy_from(&test_packet1))
        else {
            panic!("expected handshake init")
        };

        snaptun_server.handle_incoming_packet(
            Packet::copy_from(hs_init.as_bytes()),
            sockaddr_client1,
            &mut send_to_network,
        );

        dispatch_one(&mut tunn_client1, &mut send_to_network);
        assert_eq!(
            tunn_client1.get_initiator_remote_sockaddr(),
            Some(sockaddr_client1)
        );

        // Client 0's queued data packet must reach the server's tunnel for client 0
        // and come out as the original plaintext.
        let Some(WgKind::Data(p)) = tunn_client0.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };

        let TunnResult::WriteToTunnel(p) = snaptun_server.handle_incoming_packet(
            Packet::copy_from(p.as_bytes()),
            sockaddr_client0,
            &mut send_to_network,
        ) else {
            panic!("Expected packet to be processed")
        };
        assert_eq!(p.as_bytes(), test_packet0.as_bytes());

        // Likewise for client 1.
        let Some(WgKind::Data(p1)) = tunn_client1.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };

        let TunnResult::WriteToTunnel(p1) = snaptun_server.handle_incoming_packet(
            Packet::copy_from(p1.as_bytes()),
            sockaddr_client1,
            &mut send_to_network,
        ) else {
            panic!("expected packet to be received on server side");
        };
        assert_eq!(p1.as_bytes(), test_packet1.as_bytes());

        // Forward client 0's plaintext to client 1's address. Only client 1's tunnel
        // can decrypt the result, which proves the server picked the right tunnel.
        let res = snaptun_server.handle_outgoing_packet(p, sockaddr_client1);
        let Some(p @ WgKind::Data(_)) = res else {
            panic!("expected packet to be sent back to client")
        };

        let TunnResult::WriteToTunnel(p) = tunn_client1.handle_incoming_packet(p) else {
            panic!("expected packet to be sent back to client")
        };

        assert_eq!(p.as_bytes(), test_packet0.as_bytes());

        Ok(())
    }

    // Pops the oldest packet from `packets`, if any, and feeds it to `tunn`.
    // Returns the tunnel's result, or `TunnResult::Done` when the queue is empty.
    fn dispatch_one(tunn: &mut Tunn, packets: &mut VecDeque<WgKind>) -> TunnResult {
        if let Some(p) = packets.pop_front() {
            return tunn.handle_incoming_packet(p);
        }
        TunnResult::Done
    }
}