wireguard_netstack/
wireguard.rs1use bytes::BytesMut;
7use gotatun::noise::rate_limiter::RateLimiter;
8use gotatun::noise::{Tunn, TunnResult};
9use gotatun::packet::Packet;
10use gotatun::x25519::{PublicKey, StaticSecret};
11use parking_lot::Mutex;
12use zerocopy::IntoBytes;
13use std::net::{Ipv4Addr, SocketAddr};
14use std::sync::Arc;
15use std::time::Duration;
16use tokio::net::UdpSocket;
17use tokio::sync::mpsc;
18
19use crate::error::{Error, Result};
20
/// Parameters needed to bring up a single-peer WireGuard tunnel.
///
/// Consumed by `WireGuardTunnel::new`.
#[derive(Clone)]
pub struct WireGuardConfig {
    /// Local private key (32 raw bytes).
    pub private_key: [u8; 32],

    /// Public key of the remote peer (32 raw bytes).
    pub peer_public_key: [u8; 32],

    /// UDP address of the remote peer; every encrypted datagram is sent here.
    pub peer_endpoint: SocketAddr,

    /// IPv4 address associated with this tunnel (presumably the local
    /// address inside the tunnel -- confirm against the caller).
    pub tunnel_ip: Ipv4Addr,

    /// Optional pre-shared key (32 raw bytes); `None` disables it.
    pub preshared_key: Option<[u8; 32]>,

    /// Optional keepalive setting in seconds, forwarded unchanged to the
    /// underlying `Tunn`.
    pub keepalive_seconds: Option<u16>,

    /// Tunnel MTU; when `None`, `WireGuardTunnel::new` falls back to 460.
    pub mtu: Option<u16>,
}
39
40pub struct WireGuardTunnel {
42 tunn: Mutex<Tunn>,
44 udp_socket: Arc<UdpSocket>,
46 peer_endpoint: SocketAddr,
48 tunnel_ip: Ipv4Addr,
50 mtu: u16,
52 incoming_tx: mpsc::Sender<BytesMut>,
54 outgoing_rx: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
56 incoming_rx: Mutex<Option<mpsc::Receiver<BytesMut>>>,
58 outgoing_tx: mpsc::Sender<BytesMut>,
60}
61
62impl WireGuardTunnel {
63 pub async fn new(config: WireGuardConfig) -> Result<Arc<Self>> {
65 let private_key = StaticSecret::from(config.private_key);
67 let peer_public_key = PublicKey::from(config.peer_public_key);
68
69 let tunn = Tunn::new(
71 private_key,
72 peer_public_key,
73 config.preshared_key,
74 config.keepalive_seconds,
75 rand::random::<u32>() >> 8, Arc::new(RateLimiter::new(&peer_public_key, 0)),
77 );
78
79 let udp_socket = UdpSocket::bind("0.0.0.0:0").await?;
81
82 let sock_ref = socket2::SockRef::from(&udp_socket);
84 if let Err(e) = sock_ref.set_recv_buffer_size(1024 * 1024) {
85 log::warn!("Failed to set UDP recv buffer size: {}", e);
87 }
88 if let Err(e) = sock_ref.set_send_buffer_size(1024 * 1024) {
89 log::warn!("Failed to set UDP send buffer size: {}", e);
91 }
92 log::info!("UDP recv buffer size: {:?}", sock_ref.recv_buffer_size());
93 log::info!("UDP send buffer size: {:?}", sock_ref.send_buffer_size());
94
95 log::info!(
96 "WireGuard UDP socket bound to {}",
97 udp_socket.local_addr()?
98 );
99
100 let (incoming_tx, incoming_rx) = mpsc::channel(256);
104 let (outgoing_tx, outgoing_rx) = mpsc::channel(256);
105
106 let tunnel = Arc::new(Self {
107 tunn: Mutex::new(tunn),
108 udp_socket: Arc::new(udp_socket),
109 peer_endpoint: config.peer_endpoint,
110 tunnel_ip: config.tunnel_ip,
111 mtu: config.mtu.unwrap_or(460), incoming_tx,
113 incoming_rx: Mutex::new(Some(incoming_rx)),
114 outgoing_tx,
115 outgoing_rx: tokio::sync::Mutex::new(outgoing_rx),
116 });
117
118 Ok(tunnel)
119 }
120
121 pub fn tunnel_ip(&self) -> Ipv4Addr {
123 self.tunnel_ip
124 }
125
126 pub fn mtu(&self) -> u16 {
128 self.mtu
129 }
130
131 pub fn outgoing_sender(&self) -> mpsc::Sender<BytesMut> {
133 self.outgoing_tx.clone()
134 }
135
136 pub fn take_incoming_receiver(&self) -> Option<mpsc::Receiver<BytesMut>> {
138 self.incoming_rx.lock().take()
139 }
140
141 pub fn time_since_last_handshake(&self) -> Option<Duration> {
148 let tunn = self.tunn.lock();
149 tunn.stats().0
150 }
151
152 pub async fn initiate_handshake(&self) -> Result<()> {
154 log::info!("Initiating WireGuard handshake...");
155
156 let handshake_init = {
157 let mut tunn = self.tunn.lock();
158 tunn.format_handshake_initiation(false)
159 };
160
161 if let Some(packet) = handshake_init {
162 let data = packet.as_bytes();
164 self.udp_socket.send_to(data, self.peer_endpoint).await?;
165 log::debug!("Sent handshake initiation ({} bytes)", data.len());
166 }
167
168 Ok(())
169 }
170
171 pub async fn send_ip_packet(&self, packet: BytesMut) -> Result<()> {
173 let encrypted = {
174 let mut tunn = self.tunn.lock();
175 let pkt = Packet::from_bytes(packet);
176 tunn.handle_outgoing_packet(pkt)
177 };
178
179 if let Some(wg_packet) = encrypted {
180 let pkt: Packet = wg_packet.into();
182 let data = pkt.as_bytes();
183 self.udp_socket.send_to(data, self.peer_endpoint).await?;
184 log::trace!("Sent encrypted packet ({} bytes)", data.len());
185 }
186
187 Ok(())
188 }
189
190 fn process_incoming_udp(&self, data: &[u8]) -> Option<BytesMut> {
192 let packet = Packet::from_bytes(BytesMut::from(data));
193 let wg_packet = match packet.try_into_wg() {
194 Ok(wg) => wg,
195 Err(_) => {
196 log::warn!("Received non-WireGuard packet");
197 return None;
198 }
199 };
200
201 let mut tunn = self.tunn.lock();
202 match tunn.handle_incoming_packet(wg_packet) {
203 TunnResult::Done => {
204 log::trace!("WG: Packet processed (no output)");
205 None
206 }
207 TunnResult::Err(e) => {
208 log::warn!("WG error: {:?}", e);
209 None
210 }
211 TunnResult::WriteToNetwork(response) => {
212 log::trace!("WG: Sending response packet");
213 let pkt: Packet = response.into();
215 let data = BytesMut::from(pkt.as_bytes());
216 let socket = self.udp_socket.clone();
217 let endpoint = self.peer_endpoint;
218 tokio::spawn(async move {
219 if let Err(e) = socket.send_to(&data, endpoint).await {
220 log::error!("Failed to send response: {}", e);
221 }
222 });
223
224 for queued in tunn.get_queued_packets() {
226 let pkt: Packet = queued.into();
227 let data = BytesMut::from(pkt.as_bytes());
228 let socket = self.udp_socket.clone();
229 let endpoint = self.peer_endpoint;
230 tokio::spawn(async move {
231 if let Err(e) = socket.send_to(&data, endpoint).await {
232 log::error!("Failed to send queued packet: {}", e);
233 }
234 });
235 }
236
237 None
238 }
239 TunnResult::WriteToTunnel(decrypted) => {
240 if decrypted.is_empty() {
241 log::trace!("WG: Received keepalive");
242 return None;
243 }
244 let bytes = BytesMut::from(decrypted.as_bytes());
245 log::trace!("WG: Decrypted {} bytes", bytes.len());
246 Some(bytes)
247 }
248 }
249 }
250
251 pub async fn run_receive_loop(self: &Arc<Self>) -> Result<()> {
253 let mut buf = vec![0u8; 65535];
254
255 loop {
256 match self.udp_socket.recv_from(&mut buf).await {
257 Ok((len, from)) => {
258 if from != self.peer_endpoint {
259 log::warn!("Received packet from unknown peer: {}", from);
260 continue;
261 }
262
263 log::trace!("Received UDP packet ({} bytes) from {}", len, from);
264
265 if let Some(ip_packet) = self.process_incoming_udp(&buf[..len]) {
266 if self.incoming_tx.send(ip_packet).await.is_err() {
267 log::error!("Incoming channel closed");
268 break;
269 }
270 }
271 }
272 Err(e) => {
273 log::error!("UDP receive error: {}", e);
274 break;
275 }
276 }
277 }
278
279 Ok(())
280 }
281
282 pub async fn run_send_loop(self: &Arc<Self>) -> Result<()> {
284 let mut outgoing_rx = self.outgoing_rx.lock().await;
285
286 while let Some(packet) = outgoing_rx.recv().await {
287 if let Err(e) = self.send_ip_packet(packet).await {
288 log::error!("Failed to send packet: {}", e);
289 }
290 }
291
292 Ok(())
293 }
294
295 pub async fn run_timer_loop(self: &Arc<Self>) -> Result<()> {
297 let mut interval = tokio::time::interval(Duration::from_millis(250));
298
299 loop {
300 interval.tick().await;
301
302 let packets_to_send: Vec<Vec<u8>> = {
303 let mut tunn = self.tunn.lock();
304 match tunn.update_timers() {
305 Ok(Some(packet)) => {
306 let pkt: Packet = packet.into();
307 vec![pkt.as_bytes().to_vec()]
308 }
309 Ok(None) => vec![],
310 Err(e) => {
311 log::trace!("Timer error (may be normal): {:?}", e);
312 vec![]
313 }
314 }
315 };
316
317 for packet in packets_to_send {
318 if let Err(e) = self.udp_socket.send_to(&packet, self.peer_endpoint).await {
319 log::error!("Failed to send timer packet: {}", e);
320 }
321 }
322 }
323 }
324
325 pub async fn wait_for_handshake(&self, timeout_duration: Duration) -> Result<()> {
327 let start = std::time::Instant::now();
328
329 loop {
330 {
331 let tunn = self.tunn.lock();
332 let (time_since_handshake, _tx_bytes, _rx_bytes, _, _) = tunn.stats();
334 if time_since_handshake.is_some() {
335 log::info!("WireGuard handshake completed!");
336 return Ok(());
337 }
338 }
339
340 if start.elapsed() > timeout_duration {
341 return Err(Error::HandshakeTimeout(timeout_duration));
342 }
343
344 tokio::time::sleep(Duration::from_millis(50)).await;
345 }
346 }
347}