wireguard_netstack/
wireguard.rs

1//! WireGuard tunnel implementation using gotatun.
2//!
3//! This module wraps gotatun's `Tunn` struct and manages the UDP transport
4//! for sending/receiving encrypted WireGuard packets.
5
6use bytes::BytesMut;
7use gotatun::noise::{Tunn, TunnResult};
8use gotatun::packet::Packet;
9use gotatun::x25519::{PublicKey, StaticSecret};
10use parking_lot::Mutex;
11use zerocopy::IntoBytes;
12use std::net::{Ipv4Addr, SocketAddr};
13use std::sync::Arc;
14use std::time::Duration;
15use tokio::net::UdpSocket;
16use tokio::sync::mpsc;
17
18use crate::error::{Error, Result};
19
/// Configuration for the WireGuard tunnel.
///
/// The type deliberately does not derive `Debug`: a derived implementation
/// would print the private and preshared keys. The hand-written `Debug`
/// implementation below redacts both so the configuration can be logged safely.
#[derive(Clone)]
pub struct WireGuardConfig {
    /// Our private key (32 bytes).
    pub private_key: [u8; 32],
    /// Peer's public key (32 bytes).
    pub peer_public_key: [u8; 32],
    /// Peer's endpoint (IP:port) that encrypted UDP datagrams are sent to.
    pub peer_endpoint: SocketAddr,
    /// Our IP address inside the tunnel.
    pub tunnel_ip: Ipv4Addr,
    /// Optional preshared key for additional security.
    pub preshared_key: Option<[u8; 32]>,
    /// Persistent keepalive interval in seconds.
    ///
    /// The value is handed to the tunnel constructor unchanged.
    /// NOTE(review): `None` presumably disables keepalives; whether `Some(0)`
    /// is also treated as "disabled" depends on gotatun — confirm.
    pub keepalive_seconds: Option<u16>,
}

impl std::fmt::Debug for WireGuardConfig {
    /// Formats the configuration with all secret key material redacted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("WireGuardConfig")
            .field("private_key", &"<redacted>")
            .field("peer_public_key", &self.peer_public_key)
            .field("peer_endpoint", &self.peer_endpoint)
            .field("tunnel_ip", &self.tunnel_ip)
            // Keep the Some/None distinction visible without leaking the key.
            .field(
                "preshared_key",
                &self.preshared_key.as_ref().map(|_| "<redacted>"),
            )
            .field("keepalive_seconds", &self.keepalive_seconds)
            .finish()
    }
}
36
/// A WireGuard tunnel that encrypts/decrypts IP packets.
///
/// The struct owns the gotatun session state, the UDP socket used to talk to
/// the peer, and both ends of two bounded channels: "incoming" carries
/// decrypted IP packets towards the consumer, "outgoing" carries plaintext IP
/// packets that still have to be encrypted and sent.
pub struct WireGuardTunnel {
    /// The underlying gotatun tunnel (synchronous lock).
    tunn: Mutex<Tunn>,
    /// UDP socket for sending/receiving encrypted packets.
    udp_socket: Arc<UdpSocket>,
    /// Peer's endpoint address; datagrams from any other address are ignored
    /// by `run_receive_loop`.
    peer_endpoint: SocketAddr,
    /// Our tunnel IP address.
    tunnel_ip: Ipv4Addr,
    /// Producer side of the incoming channel: `run_receive_loop` pushes
    /// decrypted IP packets here.
    incoming_tx: mpsc::Sender<BytesMut>,
    /// Consumer side of the outgoing channel: drained by `run_send_loop`.
    /// Wrapped in an async mutex because the guard is held across `.await`.
    outgoing_rx: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
    /// Consumer side of the incoming channel. Stored as an `Option` so that
    /// `take_incoming_receiver` can hand it out exactly once.
    incoming_rx: Mutex<Option<mpsc::Receiver<BytesMut>>>,
    /// Producer side of the outgoing channel; cloned by `outgoing_sender`.
    outgoing_tx: mpsc::Sender<BytesMut>,
}
56
57impl WireGuardTunnel {
58    /// Create a new WireGuard tunnel with the given configuration.
59    pub async fn new(config: WireGuardConfig) -> Result<Arc<Self>> {
60        // Create the cryptographic keys
61        let private_key = StaticSecret::from(config.private_key);
62        let peer_public_key = PublicKey::from(config.peer_public_key);
63
64        // Create the tunnel
65        let tunn = Tunn::new(
66            private_key,
67            peer_public_key,
68            config.preshared_key,
69            config.keepalive_seconds,
70            rand::random::<u32>() >> 8, // Random index
71            None,                        // No rate limiter for client
72        );
73
74        // Bind UDP socket to any available port
75        let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
76
77        // Increase socket receive buffer to avoid packet loss
78        let sock_ref = socket2::SockRef::from(&udp_socket);
79        if let Err(e) = sock_ref.set_recv_buffer_size(1024 * 1024) {
80            // 1MB buffer
81            log::warn!("Failed to set UDP recv buffer size: {}", e);
82        }
83        if let Err(e) = sock_ref.set_send_buffer_size(1024 * 1024) {
84            // 1MB buffer
85            log::warn!("Failed to set UDP send buffer size: {}", e);
86        }
87        log::info!("UDP recv buffer size: {:?}", sock_ref.recv_buffer_size());
88        log::info!("UDP send buffer size: {:?}", sock_ref.send_buffer_size());
89
90        log::info!(
91            "WireGuard UDP socket bound to {}",
92            udp_socket.local_addr()?
93        );
94
95        // Create channels for packet communication
96        // incoming: packets received from the tunnel (decrypted)
97        // outgoing: packets to send through the tunnel (to be encrypted)
98        let (incoming_tx, incoming_rx) = mpsc::channel(256);
99        let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
100
101        let tunnel = Arc::new(Self {
102            tunn: Mutex::new(tunn),
103            udp_socket: Arc::new(udp_socket),
104            peer_endpoint: config.peer_endpoint,
105            tunnel_ip: config.tunnel_ip,
106            incoming_tx,
107            incoming_rx: Mutex::new(Some(incoming_rx)),
108            outgoing_tx,
109            outgoing_rx: tokio::sync::Mutex::new(outgoing_rx),
110        });
111
112        Ok(tunnel)
113    }
114
115    /// Get our tunnel IP address.
116    pub fn tunnel_ip(&self) -> Ipv4Addr {
117        self.tunnel_ip
118    }
119
120    /// Get the sender for outgoing packets.
121    pub fn outgoing_sender(&self) -> mpsc::Sender<BytesMut> {
122        self.outgoing_tx.clone()
123    }
124
125    /// Get the receiver for incoming packets (takes ownership of the receiver).
126    pub fn take_incoming_receiver(&self) -> Option<mpsc::Receiver<BytesMut>> {
127        self.incoming_rx.lock().take()
128    }
129
130    /// Initiate the WireGuard handshake.
131    pub async fn initiate_handshake(&self) -> Result<()> {
132        log::info!("Initiating WireGuard handshake...");
133
134        let handshake_init = {
135            let mut tunn = self.tunn.lock();
136            tunn.format_handshake_initiation(false)
137        };
138
139        if let Some(packet) = handshake_init {
140            // Convert Packet<WgHandshakeInit> to bytes
141            let data = packet.as_bytes();
142            self.udp_socket.send_to(data, self.peer_endpoint).await?;
143            log::debug!("Sent handshake initiation ({} bytes)", data.len());
144        }
145
146        Ok(())
147    }
148
149    /// Send an IP packet through the tunnel (encrypts and sends via UDP).
150    pub async fn send_ip_packet(&self, packet: BytesMut) -> Result<()> {
151        let encrypted = {
152            let mut tunn = self.tunn.lock();
153            let pkt = Packet::from_bytes(packet);
154            tunn.handle_outgoing_packet(pkt)
155        };
156
157        if let Some(wg_packet) = encrypted {
158            // Convert WgKind to Packet<[u8]> and get bytes
159            let pkt: Packet = wg_packet.into();
160            let data = pkt.as_bytes();
161            self.udp_socket.send_to(data, self.peer_endpoint).await?;
162            log::trace!("Sent encrypted packet ({} bytes)", data.len());
163        }
164
165        Ok(())
166    }
167
168    /// Process a received UDP packet (decrypts and returns IP packet if any).
169    fn process_incoming_udp(&self, data: &[u8]) -> Option<BytesMut> {
170        let packet = Packet::from_bytes(BytesMut::from(data));
171        let wg_packet = match packet.try_into_wg() {
172            Ok(wg) => wg,
173            Err(_) => {
174                log::warn!("Received non-WireGuard packet");
175                return None;
176            }
177        };
178
179        let mut tunn = self.tunn.lock();
180        match tunn.handle_incoming_packet(wg_packet) {
181            TunnResult::Done => {
182                log::trace!("WG: Packet processed (no output)");
183                None
184            }
185            TunnResult::Err(e) => {
186                log::warn!("WG error: {:?}", e);
187                None
188            }
189            TunnResult::WriteToNetwork(response) => {
190                log::trace!("WG: Sending response packet");
191                // Need to send a response (e.g., handshake response, keepalive)
192                let pkt: Packet = response.into();
193                let data = BytesMut::from(pkt.as_bytes());
194                let socket = self.udp_socket.clone();
195                let endpoint = self.peer_endpoint;
196                tokio::spawn(async move {
197                    if let Err(e) = socket.send_to(&data, endpoint).await {
198                        log::error!("Failed to send response: {}", e);
199                    }
200                });
201
202                // Also try to send any queued packets
203                while let Some(queued) = tunn.next_queued_packet() {
204                    let pkt: Packet = queued.into();
205                    let data = BytesMut::from(pkt.as_bytes());
206                    let socket = self.udp_socket.clone();
207                    let endpoint = self.peer_endpoint;
208                    tokio::spawn(async move {
209                        if let Err(e) = socket.send_to(&data, endpoint).await {
210                            log::error!("Failed to send queued packet: {}", e);
211                        }
212                    });
213                }
214
215                None
216            }
217            TunnResult::WriteToTunnel(decrypted) => {
218                if decrypted.is_empty() {
219                    log::trace!("WG: Received keepalive");
220                    return None;
221                }
222                let bytes = BytesMut::from(decrypted.as_bytes());
223                log::trace!("WG: Decrypted {} bytes", bytes.len());
224                Some(bytes)
225            }
226        }
227    }
228
229    /// Run the tunnel's receive loop (listens for UDP packets and decrypts them).
230    pub async fn run_receive_loop(self: &Arc<Self>) -> Result<()> {
231        let mut buf = vec![0u8; 65535];
232
233        loop {
234            match self.udp_socket.recv_from(&mut buf).await {
235                Ok((len, from)) => {
236                    if from != self.peer_endpoint {
237                        log::warn!("Received packet from unknown peer: {}", from);
238                        continue;
239                    }
240
241                    log::trace!("Received UDP packet ({} bytes) from {}", len, from);
242
243                    if let Some(ip_packet) = self.process_incoming_udp(&buf[..len]) {
244                        if self.incoming_tx.send(ip_packet).await.is_err() {
245                            log::error!("Incoming channel closed");
246                            break;
247                        }
248                    }
249                }
250                Err(e) => {
251                    log::error!("UDP receive error: {}", e);
252                    break;
253                }
254            }
255        }
256
257        Ok(())
258    }
259
260    /// Run the tunnel's send loop (encrypts and sends IP packets).
261    pub async fn run_send_loop(self: &Arc<Self>) -> Result<()> {
262        let mut outgoing_rx = self.outgoing_rx.lock().await;
263
264        while let Some(packet) = outgoing_rx.recv().await {
265            if let Err(e) = self.send_ip_packet(packet).await {
266                log::error!("Failed to send packet: {}", e);
267            }
268        }
269
270        Ok(())
271    }
272
273    /// Run the tunnel's timer loop (handles keepalives and handshake retries).
274    pub async fn run_timer_loop(self: &Arc<Self>) -> Result<()> {
275        let mut interval = tokio::time::interval(Duration::from_millis(250));
276
277        loop {
278            interval.tick().await;
279
280            let packets_to_send: Vec<Vec<u8>> = {
281                let mut tunn = self.tunn.lock();
282                match tunn.update_timers() {
283                    Ok(Some(packet)) => {
284                        let pkt: Packet = packet.into();
285                        vec![pkt.as_bytes().to_vec()]
286                    }
287                    Ok(None) => vec![],
288                    Err(e) => {
289                        log::trace!("Timer error (may be normal): {:?}", e);
290                        vec![]
291                    }
292                }
293            };
294
295            for packet in packets_to_send {
296                if let Err(e) = self.udp_socket.send_to(&packet, self.peer_endpoint).await {
297                    log::error!("Failed to send timer packet: {}", e);
298                }
299            }
300        }
301    }
302
303    /// Wait for the handshake to complete (with timeout).
304    pub async fn wait_for_handshake(&self, timeout_duration: Duration) -> Result<()> {
305        let start = std::time::Instant::now();
306
307        loop {
308            {
309                let tunn = self.tunn.lock();
310                // Check if we have an active session - time_since_handshake is Some when session is established
311                let (time_since_handshake, _tx_bytes, _rx_bytes, _, _) = tunn.stats();
312                if time_since_handshake.is_some() {
313                    log::info!("WireGuard handshake completed!");
314                    return Ok(());
315                }
316            }
317
318            if start.elapsed() > timeout_duration {
319                return Err(Error::HandshakeTimeout(timeout_duration));
320            }
321
322            tokio::time::sleep(Duration::from_millis(50)).await;
323        }
324    }
325}