Skip to main content

aivpn_server/
nat.rs

//! NAT forwarder module.
//!
//! Handles:
//! - TUN device creation
//! - Packet forwarding to the internet
//! - NAT masquerading and TCP MSS clamping via iptables (Linux)
7
8use std::io;
9use std::sync::Arc;
10use tokio::io::AsyncWriteExt;
11use tokio::sync::Mutex;
12#[cfg(target_os = "linux")]
13use tracing::warn;
14use tracing::{debug, info};
15
16use aivpn_common::error::{Error, Result};
17use aivpn_common::network_config::VpnNetworkConfig;
18
/// MTU applied to the NAT TUN interface when it is created (see `NatForwarder::create`).
const TUN_MTU: u16 = 1_420;
20
/// NAT Forwarder for routing traffic to internet
/// Uses split reader/writer to avoid mutex starvation
///
/// The TUN device is split into separate reader and writer halves in
/// `create()`, and each half can be claimed by its own task via
/// `take_reader()` / `take_writer()`.
pub struct NatForwarder {
    /// Interface name given to the TUN configuration; also used as the
    /// `-i`/`-o` interface in the iptables FORWARD/mangle rules.
    tun_name: String,
    /// Address assigned to the TUN interface.
    tun_addr: String,
    /// Netmask assigned to the TUN interface.
    tun_netmask: String,
    /// VPN network; its `cidr_string()` is the source subnet for the
    /// MASQUERADE rule.
    network_config: VpnNetworkConfig,
    /// Optional shared writer read by `forward_packet`.
    /// NOTE(review): nothing in this file ever sets this to `Some`;
    /// `create()` explicitly leaves it `None`.
    writer: Option<Arc<Mutex<tun::DeviceWriter>>>,
    /// Slot holding the writer half after `create()`; `take_writer()` moves
    /// it out and leaves `None` behind.
    writer_taken: Option<Mutex<Option<tun::DeviceWriter>>>,
    /// Slot holding the reader half after `create()`; `take_reader()` moves
    /// it out and leaves `None` behind.
    reader: Option<Mutex<Option<tun::DeviceReader>>>,
}
32
33impl NatForwarder {
34    pub fn new(
35        tun_name: &str,
36        tun_addr: &str,
37        tun_netmask: &str,
38        network_config: VpnNetworkConfig,
39    ) -> Result<Self> {
40        Ok(Self {
41            tun_name: tun_name.to_string(),
42            tun_addr: tun_addr.to_string(),
43            tun_netmask: tun_netmask.to_string(),
44            network_config,
45            writer: None,
46            writer_taken: None,
47            reader: None,
48        })
49    }
50
51    /// Create TUN device for NAT
52    pub fn create(&mut self) -> Result<()> {
53        let mut config = tun::Configuration::default();
54
55        config
56            .tun_name(&self.tun_name)
57            .address(&self.tun_addr)
58            .netmask(&self.tun_netmask)
59            .mtu(TUN_MTU)
60            .up();
61
62        #[cfg(target_os = "linux")]
63        config.platform_config(|config| {
64            config.ensure_root_privileges(true);
65        });
66
67        let dev = tun::create_as_async(&config)
68            .map_err(|e| Error::Io(io::Error::new(io::ErrorKind::Other, e.to_string())))?;
69
70        let (writer, reader) = dev
71            .split()
72            .map_err(|e| Error::Io(io::Error::new(io::ErrorKind::Other, e.to_string())))?;
73        self.writer = None; // Writer accessed via take_writer() for channel-based I/O
74        self.writer_taken = Some(Mutex::new(Some(writer)));
75        self.reader = Some(Mutex::new(Some(reader)));
76
77        info!(
78            "Created NAT TUN device: {} ({}/{}, subnet {})",
79            self.tun_name,
80            self.tun_addr,
81            self.tun_netmask,
82            self.network_config.cidr_string(),
83        );
84
85        // Enable IP forwarding (Linux)
86        #[cfg(target_os = "linux")]
87        {
88            self.enable_ip_forwarding()?;
89            self.setup_iptables()?;
90        }
91
92        Ok(())
93    }
94
95    /// Enable IP forwarding on Linux
96    #[cfg(target_os = "linux")]
97    fn enable_ip_forwarding(&self) -> Result<()> {
98        use std::fs::{read_to_string, write};
99
100        // Check if already enabled (e.g. inside Docker with host sysctl)
101        if let Ok(val) = read_to_string("/proc/sys/net/ipv4/ip_forward") {
102            if val.trim() == "1" {
103                info!("IPv4 forwarding already enabled");
104                return Ok(());
105            }
106        }
107
108        // Try to enable IPv4 forwarding
109        write("/proc/sys/net/ipv4/ip_forward", "1").map_err(|e| {
110            Error::Io(io::Error::new(
111                io::ErrorKind::PermissionDenied,
112                format!("Failed to enable IP forwarding: {}", e),
113            ))
114        })?;
115
116        info!("Enabled IPv4 forwarding");
117        Ok(())
118    }
119
120    /// Setup iptables rules for NAT
121    #[cfg(target_os = "linux")]
122    fn setup_iptables(&self) -> Result<()> {
123        use std::process::Command;
124
125        // Enable NAT masquerading
126        let output = Command::new("iptables")
127            .args([
128                "-t",
129                "nat",
130                "-A",
131                "POSTROUTING",
132                "-s",
133                &self.network_config.cidr_string(),
134                "-j",
135                "MASQUERADE",
136            ])
137            .output();
138
139        match output {
140            Ok(out) => {
141                if out.status.success() {
142                    info!("Added iptables MASQUERADE rule");
143                } else {
144                    let stderr = String::from_utf8_lossy(&out.stderr);
145                    warn!("iptables rule failed: {}", stderr);
146                }
147            }
148            Err(e) => {
149                warn!("iptables command not found: {}", e);
150            }
151        }
152
153        // Allow forwarding
154        let _ = Command::new("iptables")
155            .args(["-A", "FORWARD", "-i", &self.tun_name, "-j", "ACCEPT"])
156            .output();
157
158        let _ = Command::new("iptables")
159            .args([
160                "-A",
161                "FORWARD",
162                "-o",
163                &self.tun_name,
164                "-m",
165                "state",
166                "--state",
167                "RELATED,ESTABLISHED",
168                "-j",
169                "ACCEPT",
170            ])
171            .output();
172
173        // Clamp TCP MSS across the TUN boundary to avoid PMTU blackholes
174        // on download-heavy flows when the VPN MTU is lower than the uplink MTU.
175        let _ = Command::new("iptables")
176            .args([
177                "-t",
178                "mangle",
179                "-A",
180                "FORWARD",
181                "-o",
182                &self.tun_name,
183                "-p",
184                "tcp",
185                "--tcp-flags",
186                "SYN,RST",
187                "SYN",
188                "-j",
189                "TCPMSS",
190                "--clamp-mss-to-pmtu",
191            ])
192            .output();
193
194        let _ = Command::new("iptables")
195            .args([
196                "-t",
197                "mangle",
198                "-A",
199                "FORWARD",
200                "-i",
201                &self.tun_name,
202                "-p",
203                "tcp",
204                "--tcp-flags",
205                "SYN,RST",
206                "SYN",
207                "-j",
208                "TCPMSS",
209                "--clamp-mss-to-pmtu",
210            ])
211            .output();
212
213        Ok(())
214    }
215
216    /// Forward packet to TUN (write)
217    pub async fn forward_packet(&self, packet: &[u8]) -> Result<()> {
218        let writer = self.writer.as_ref().ok_or_else(|| {
219            Error::Io(io::Error::new(
220                io::ErrorKind::NotConnected,
221                "TUN device not created",
222            ))
223        })?;
224
225        let mut w = writer.lock().await;
226
227        // Linux TUN with IFF_NO_PI (default) expects raw IP packets
228        // No flush() — let the OS buffer writes naturally for throughput
229        w.write_all(packet).await?;
230
231        debug!("Forwarded {} bytes to TUN", packet.len());
232        Ok(())
233    }
234
235    /// Take ownership of the TUN writer (for use in a dedicated writer task)
236    pub async fn take_writer(&self) -> Option<tun::DeviceWriter> {
237        if let Some(ref lock) = self.writer_taken {
238            lock.lock().await.take()
239        } else {
240            None
241        }
242    }
243
244    /// Take ownership of the TUN reader (for use in a spawned task)
245    pub async fn take_reader(&self) -> Option<tun::DeviceReader> {
246        if let Some(reader_lock) = &self.reader {
247            reader_lock.lock().await.take()
248        } else {
249            None
250        }
251    }
252
253    /// Get TUN device name
254    pub fn tun_name(&self) -> &str {
255        &self.tun_name
256    }
257}
258
259impl Drop for NatForwarder {
260    fn drop(&mut self) {
261        if self.writer.is_some() {
262            info!("Closing NAT TUN device: {}", self.tun_name);
263        }
264
265        // Cleanup iptables (optional, rules persist)
266        #[cfg(target_os = "linux")]
267        {
268            use std::process::Command;
269            let _ = Command::new("iptables")
270                .args([
271                    "-t",
272                    "nat",
273                    "-D",
274                    "POSTROUTING",
275                    "-s",
276                    &self.network_config.cidr_string(),
277                    "-j",
278                    "MASQUERADE",
279                ])
280                .output();
281
282            let _ = Command::new("iptables")
283                .args([
284                    "-t",
285                    "mangle",
286                    "-D",
287                    "FORWARD",
288                    "-o",
289                    &self.tun_name,
290                    "-p",
291                    "tcp",
292                    "--tcp-flags",
293                    "SYN,RST",
294                    "SYN",
295                    "-j",
296                    "TCPMSS",
297                    "--clamp-mss-to-pmtu",
298                ])
299                .output();
300
301            let _ = Command::new("iptables")
302                .args([
303                    "-t",
304                    "mangle",
305                    "-D",
306                    "FORWARD",
307                    "-i",
308                    &self.tun_name,
309                    "-p",
310                    "tcp",
311                    "--tcp-flags",
312                    "SYN,RST",
313                    "SYN",
314                    "-j",
315                    "TCPMSS",
316                    "--clamp-mss-to-pmtu",
317                ])
318                .output();
319        }
320    }
321}