Skip to main content

snap_tun/
server.rs

1// Copyright 2026 Anapaya Systems
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//   http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//! The server of the SNAPtun protocol.
15//!
//! As the underlying protocol is symmetric (both peers can act as initiators
//! or responders that establish a session), there is technically no server.
//! The term "server" here simply refers to an endpoint that manages multiple
//! peers.
20
21use std::{
22    collections::{HashMap, VecDeque},
23    net::SocketAddr,
24    sync::Arc,
25    time::Instant,
26};
27
28use ana_gotatun::{
29    noise::{Tunn, TunnResult, handshake::parse_handshake_anon, rate_limiter::RateLimiter},
30    packet::{Packet, WgKind},
31    x25519,
32};
33
34/// The [SnapTunServer] manages one [Tunn] per remote socket address.
35///
36/// The main structural difference between WireGuard (R) and snaptun-ng is that
37/// there is a one-to-one relation between a remote socket address (of the
38/// initiator) and a tunnel. The [SnapTunServer] manages that relation.
39///
40/// ## Scaling
41///
42/// The main methods [SnapTunServer::handle_incoming_packet],
43/// [SnapTunServer::handle_outgoing_packet], and
44/// [SnapTunServer::update_timers] all require an exclusive reference to the
45/// internal state. The reason is that processing both, incoming and outgoing
46/// packets requires access to the session state.
47///
48/// One simple way to achieve load distribution across different cores/threads
49/// is to shard over multiple [SnapTunServer]-instances based on a hash of the
50/// remote socket address.
51///
52/// ## Future improvements
53///
54/// * Separate incoming and outgoing code paths and optimistically lock the session state.
55///
56/// ## How to use
57///
58/// The [SnapTunServer] is i/o-free; i.e. it only manages state. The following
59/// is a pseudo-code like description of the simplest i/o-layer integration:
60///
61/// ```text
62/// let mut server = SnapTunServer::new(/*...*/);
63/// let mut send_to_network = VecDequeue::new();
64/// let mut current_sockaddr = ;
65/// loop {
66///   switch {
67///     (network_packet, sockaddr) = network_socket => {
68///       server.handle_incoming_packet(/*...*/);
69///       /* dispatch packets to tunnel if necessary */
70///     }
71///     tunnel_packet = tunnel_socket => {
72///       server.handle_outgoing_packet(/*...*/);
73///     }
74///     timer = tick(250ms) => {
75///       server.update_timers();
76///     }
77///   }
78///   // dispatch packets to network
79///   for p in send_to_network {
80///     network_socket.send(sockaddr, p);
81///   }
82/// }
83/// ```
84pub struct SnapTunServer<T> {
85    static_private: x25519::StaticSecret,
86    static_public: x25519::PublicKey,
87    active_tunnels: HashMap<SocketAddr, (x25519::PublicKey, Tunn)>,
88    rate_limiter: Arc<RateLimiter>,
89    authz: Arc<T>,
90}
91
92impl<T: SnapTunAuthorization> SnapTunServer<T> {
93    /// Creates a new [SnapTunServer] instance.
94    pub fn new(
95        static_private: x25519::StaticSecret,
96        rate_limiter: Arc<RateLimiter>,
97        authz: Arc<T>,
98    ) -> Self {
99        let static_public = x25519::PublicKey::from(&static_private);
100        Self {
101            static_private,
102            static_public,
103            active_tunnels: Default::default(),
104            rate_limiter,
105            authz,
106        }
107    }
108
109    /// Handle incoming packet for a tunnel assocated with remote socket address
110    /// `from`.
111    ///
112    /// This method _never_ returns [TunnResult::WriteToNetwork]. Instead,
113    /// it codifies the expected protocol behavior which is that, upon receiving
114    /// a packet from the remote, the queue of outgoing packets is completely
115    /// drained.
116    ///
117    /// If the rate limiter signals that the server is under load, at most one
118    /// packet is added to the queue.
119    #[tracing::instrument(skip_all, fields(remote = %from))]
120    pub fn handle_incoming_packet(
121        &mut self,
122        packet: Packet,
123        from: SocketAddr,
124        send_to_network: &mut VecDeque<WgKind>,
125    ) -> TunnResult {
126        let now = Instant::now();
127
128        let parsed_packet = match self.rate_limiter.verify_packet(from.ip(), packet) {
129            Ok(p) => p,
130            Err(TunnResult::WriteToNetwork(c)) => {
131                tracing::debug!(remote = ?from, "rate limiter issued cookie reply");
132                send_to_network.push_back(c);
133                return TunnResult::Done;
134            }
135            Err(e) => {
136                tracing::debug!(remote = ?from, err = ?e, "rate limiter rejected packet");
137                return e;
138            }
139        };
140
141        use std::collections::hash_map::Entry;
142
143        use ana_gotatun::noise::errors::WireGuardError;
144        match (self.active_tunnels.entry(from), parsed_packet) {
145            (Entry::Occupied(mut occupied_entry), p) => {
146                let (peer_static, tunn) = occupied_entry.get_mut();
147                // TODO(dsd): At the moment, this keeps a tunnel alive even
148                // though the processing might fail, but gives the authorization
149                // layer a chance to block incomding packets in case an identity
150                // is unauthorized.
151                //
152                // Will fix later.
153                if !self.authz.is_authorized(now, peer_static.as_bytes()) {
154                    tracing::debug!(remote = ?from, "rejected packet from unauthorized peer");
155                    return TunnResult::Err(WireGuardError::UnexpectedPacket);
156                }
157                Self::handle_incoming_and_drain_queue(send_to_network, p, tunn)
158            }
159            (e, WgKind::HandshakeInit(wg_init)) => {
160                let peer = match parse_handshake_anon(
161                    &self.static_private,
162                    &self.static_public,
163                    &wg_init,
164                ) {
165                    Ok(v) => v,
166                    Err(e) => {
167                        tracing::debug!(remote = ?from, err = ?e, "failed to parse handshake init");
168                        return TunnResult::from(e);
169                    }
170                };
171
172                // TODO(dsd): if the socket is occupied, and tunnel.identity !=
173                // peer.identity, then send a cookie and abort
174
175                // TODO(dsd): extend ana-gotatun::Tunn such that peer static
176                // identity can be retrieved
177                if !self.authz.is_authorized(now, &peer.peer_static_public) {
178                    tracing::debug!(remote = ?from, "rejected handshake from unauthorized peer");
179                    return TunnResult::Err(WireGuardError::UnexpectedPacket);
180                }
181                tracing::debug!(remote = ?from, "accepted new handshake, inserting tunnel");
182                let peer_static = x25519::PublicKey::from(peer.peer_static_public);
183                let mut tunn = Tunn::new(
184                    self.static_private.clone(),
185                    peer_static,
186                    None,
187                    None,
188                    0,
189                    self.rate_limiter.clone(),
190                    from,
191                );
192                let res = Self::handle_incoming_and_drain_queue(
193                    send_to_network,
194                    WgKind::HandshakeInit(wg_init),
195                    &mut tunn,
196                );
197                e.insert_entry((peer_static, tunn));
198                res
199            }
200            (_, _p) => {
201                tracing::debug!(remote = ?from, "received unexpected packet kind for new entry");
202                TunnResult::Err(WireGuardError::InvalidPacket)
203            }
204        }
205    }
206
207    /// Handles an outgoing packet sent through the tunnel identified by the
208    /// remote socket address `to`.
209    #[tracing::instrument(skip_all, fields(remote = %to))]
210    pub fn handle_outgoing_packet(&mut self, packet: Packet, to: SocketAddr) -> Option<WgKind> {
211        let Some((_, tunn)) = self.active_tunnels.get_mut(&to) else {
212            tracing::error!(to=?to, "No tunnel for outgoing packet found.");
213            return None;
214        };
215        tunn.handle_outgoing_packet(packet.into_bytes())
216    }
217
218    /// Update timers of all tunnels. Generate corresponding keepalive or
219    /// session handshake initializations.
220    ///
221    /// As a result of this call, all expired tunnels are removed. Note that
222    /// this is not the same as unauthorized tunnels.
223    pub fn update_timers(&mut self) -> Vec<(SocketAddr, WgKind)> {
224        let mut res = vec![];
225        self.active_tunnels.retain(|k, (_, tunn)| {
226            match tunn.update_timers() {
227                Ok(Some(wg)) => res.push((*k, wg)),
228                Ok(None) => {},
229                Err(e) => tracing::error!(err=?e, remote_sockaddr=?k, "error when updating timers on tunnel"),
230            }
231
232            !tunn.is_expired()
233        });
234        res
235    }
236
237    fn handle_incoming_and_drain_queue(
238        q: &mut VecDeque<WgKind>,
239        p: WgKind,
240        tunn: &mut Tunn,
241    ) -> TunnResult {
242        let r = match tunn.handle_incoming_packet(p) {
243            TunnResult::WriteToNetwork(p) => {
244                q.push_back(p);
245                TunnResult::Done
246            }
247            // keep alive
248            TunnResult::WriteToTunnel(p) if p.is_empty() => TunnResult::Done,
249            r => r,
250        };
251        for p in tunn.get_queued_packets() {
252            q.push_back(p);
253        }
254        r
255    }
256}
257
/// Authorization policy of the snaptun server.
///
/// The server asks this policy before a handshake initiation creates a new
/// tunnel, and before each packet of an existing tunnel is processed.
/// Implementations must be `Send + Sync`.
pub trait SnapTunAuthorization
where
    Self: Send + Sync,
{
    /// Decides whether the peer whose static public key is `identity` may
    /// send traffic to the server at time `now`.
    ///
    /// Returns `true` iff the peer is allowed.
    fn is_authorized(&self, now: Instant, identity: &[u8; 32]) -> bool;
}
263
#[cfg(test)]
mod tests {
    // Tests for the socket-address-keyed tunnel management of SnapTunServer.
    use std::{collections::VecDeque, net::SocketAddr, sync::Arc};

    use ana_gotatun::{
        noise::{Tunn, TunnResult, rate_limiter::RateLimiter},
        packet::{IpNextProtocol, Packet, WgKind},
        x25519,
    };
    use zerocopy::IntoBytes;

    use crate::{
        scion_packet::{Scion, ScionHeader},
        server::{SnapTunAuthorization, SnapTunServer},
    };

    type ResultT = Result<(), Box<dyn std::error::Error>>;

    /// Authorization policy that accepts every peer identity.
    struct TrivialAuthz;

    impl SnapTunAuthorization for TrivialAuthz {
        fn is_authorized(&self, _now: std::time::Instant, _ident: &[u8; 32]) -> bool {
            true
        }
    }

    /// End-to-end exchange between one server and two clients at distinct
    /// socket addresses.
    ///
    /// Both clients complete a handshake, and each sends one data packet that
    /// the server decrypts. The server then sends client0's payload to client1
    /// through the tunnel keyed by client1's socket address.
    #[test]
    fn connect_with_multiple_clients() -> ResultT {
        // Fixed keys and addresses for two clients and one server.
        let sockaddr_client0: SocketAddr = "192.168.1.1:1234".parse().unwrap();
        let static_client0 = x25519::StaticSecret::from([0u8; 32]);
        let sockaddr_client1: SocketAddr = "192.168.1.2:4321".parse().unwrap();
        let static_client1 = x25519::StaticSecret::from([1u8; 32]);
        let sockaddr_server: SocketAddr = "10.0.0.1:5001".parse().unwrap();
        let static_server = x25519::StaticSecret::from([2u8; 32]);
        let static_server_public = x25519::PublicKey::from(&static_server);

        // A single rate limiter instance is shared by the server and both
        // client tunnels.
        let rate_limiter = Arc::new(RateLimiter::new(&static_server_public, 100));
        let mut snaptun_server =
            SnapTunServer::new(static_server, rate_limiter.clone(), Arc::new(TrivialAuthz));

        let mut send_to_network = VecDeque::<WgKind>::new();

        // Two SCION packets that share a header and differ only in payload.
        let test_payload0 = [b'T', b'E', b'S', b'T', b'0'];
        let test_payload1 = [b'T', b'E', b'S', b'T', b'1'];
        let test_packet0 = Scion {
            header: ScionHeader::new(
                0,                        // version
                0xAA,                     // traffic_class
                0xABCDE,                  // flow_id (20 bits)
                test_payload0.len() as _, // payload_len
                IpNextProtocol::Udp,
                7, // hop_count
                0x0123_4567_89AB_CDEF,
                0xFEDC_BA98_7654_3210,
            ),
            payload: test_payload0,
        };
        let test_packet1 = Scion {
            header: test_packet0.header,
            payload: test_payload1,
        };
        let test_packet0 = Packet::copy_from(test_packet0.as_bytes());
        let test_packet1 = Packet::copy_from(test_packet1.as_bytes());

        let mut tunn_client0 = Tunn::new(
            static_client0,
            static_server_public,
            None,
            None,
            0,
            rate_limiter.clone(),
            sockaddr_server,
        );

        let mut tunn_client1 = Tunn::new(
            static_client1,
            static_server_public,
            None,
            None,
            0,
            rate_limiter,
            sockaddr_server,
        );

        /* handshake 0 */
        // Sending data through a fresh client tunnel yields a handshake
        // initiation. The data packet itself is retrieved later via
        // `get_queued_packets`.
        let Some(WgKind::HandshakeInit(hs_init)) =
            tunn_client0.handle_outgoing_packet(Packet::copy_from(&test_packet0))
        else {
            panic!("expected handshake init")
        };

        snaptun_server.handle_incoming_packet(
            Packet::copy_from(hs_init.as_bytes()),
            sockaddr_client0,
            &mut send_to_network,
        );

        // Deliver the server's reply (front of the queue) to client0.
        dispatch_one(&mut tunn_client0, &mut send_to_network);
        assert_eq!(
            tunn_client0.get_initiator_remote_sockaddr(),
            Some(sockaddr_client0)
        );

        /* handshake 1 */
        // Same exchange for client1 from a different socket address.
        let Some(WgKind::HandshakeInit(hs_init)) =
            tunn_client1.handle_outgoing_packet(Packet::copy_from(&test_packet1))
        else {
            panic!("expected handshake init")
        };

        snaptun_server.handle_incoming_packet(
            Packet::copy_from(hs_init.as_bytes()),
            sockaddr_client1,
            &mut send_to_network,
        );

        dispatch_one(&mut tunn_client1, &mut send_to_network);
        assert_eq!(
            tunn_client1.get_initiator_remote_sockaddr(),
            Some(sockaddr_client1)
        );

        /* send C0 -> S */
        // The data packet that triggered handshake 0 is now available from
        // client0's queue.
        let Some(WgKind::Data(p)) = tunn_client0.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };

        let TunnResult::WriteToTunnel(p) = snaptun_server.handle_incoming_packet(
            Packet::copy_from(p.as_bytes()),
            sockaddr_client0,
            &mut send_to_network,
        ) else {
            panic!("Expected packet to be processed")
        };
        assert_eq!(p.as_bytes(), test_packet0.as_bytes());

        /* send C1 -> S */
        // before we can send a packet to client1, we need to send a packet from
        // client1 so the server starts using the session.
        let Some(WgKind::Data(p1)) = tunn_client1.get_queued_packets().next() else {
            panic!("expected packet to be queued");
        };

        let TunnResult::WriteToTunnel(p1) = snaptun_server.handle_incoming_packet(
            Packet::copy_from(p1.as_bytes()),
            sockaddr_client1,
            &mut send_to_network,
        ) else {
            panic!("expected packet to be received on server side");
        };
        assert_eq!(p1.as_bytes(), test_packet1.as_bytes());

        /* send S -> C1 */
        // `p` holds client0's decrypted payload. Routing it through client1's
        // tunnel must deliver test_packet0 to client1.
        let res = snaptun_server.handle_outgoing_packet(p, sockaddr_client1);
        let Some(p @ WgKind::Data(_)) = res else {
            panic!("expected packet to be sent back to client")
        };

        let TunnResult::WriteToTunnel(p) = tunn_client1.handle_incoming_packet(p) else {
            panic!("expected packet to be sent back to client")
        };

        assert_eq!(p.as_bytes(), test_packet0.as_bytes());

        Ok(())
    }

    /// Pops the oldest packet from `packets` (if any) and feeds it into `tunn`.
    ///
    /// Returns the tunnel's result, or [TunnResult::Done] if `packets` is empty.
    fn dispatch_one(tunn: &mut Tunn, packets: &mut VecDeque<WgKind>) -> TunnResult {
        if let Some(p) = packets.pop_front() {
            return tunn.handle_incoming_packet(p);
        }
        TunnResult::Done
    }
}
437}