Skip to main content

fips_core/gateway/
nat.rs

//! NAT rule management.
//!
//! Manages nftables DNAT/SNAT rules via the rustables netlink API for
//! translating between virtual IPs and FIPS mesh addresses, together with
//! the masquerade and inbound port-forward rules the gateway installs.
5
6use std::collections::HashMap;
7use std::net::Ipv6Addr;
8use tracing::{debug, info};
9
10use rustables::expr::{
11    Cmp, CmpOp, HighLevelPayload, IPv6HeaderField, Immediate, Masquerade, Meta, MetaType, Nat,
12    NatType, NetworkHeaderField, Register, TCPHeaderField, TransportHeaderField, UDPHeaderField,
13};
14use rustables::{Batch, Chain, ChainType, Hook, HookClass, MsgType, ProtocolFamily, Rule, Table};
15
16use crate::config::{PortForward, Proto};
17
/// Name of the nftables table owned by the gateway.
const TABLE_NAME: &str = "fips_gateway";

/// Name of the NAT base chain attached to the prerouting hook (DNAT rules).
const PREROUTING_CHAIN: &str = "prerouting";
/// Name of the NAT base chain attached to the postrouting hook (SNAT and
/// masquerade rules).
const POSTROUTING_CHAIN: &str = "postrouting";

/// Prerouting chain priority; equals nftables' standard `dstnat` priority.
const DSTNAT_PRIORITY: i32 = -100;
/// Postrouting chain priority; equals nftables' standard `srcnat` priority.
const SRCNAT_PRIORITY: i32 = 100;
25
26/// Errors from NAT operations.
27#[derive(Debug, thiserror::Error)]
28pub enum NatError {
29    #[error("nftables error: {0}")]
30    Nftables(String),
31    #[error("rule not found for virtual IP {0}")]
32    RuleNotFound(Ipv6Addr),
33}
34
35impl From<rustables::error::QueryError> for NatError {
36    fn from(e: rustables::error::QueryError) -> Self {
37        NatError::Nftables(e.to_string())
38    }
39}
40
41impl From<rustables::error::BuilderError> for NatError {
42    fn from(e: rustables::error::BuilderError) -> Self {
43        NatError::Nftables(e.to_string())
44    }
45}
46
/// One virtual-IP ↔ mesh-address pair. Each pair produces a DNAT rule
/// (virtual IP → mesh address) and an SNAT rule (mesh address → virtual IP).
#[derive(Clone)]
struct NatMapping {
    /// Destination address matched by the DNAT rule; also the address the
    /// SNAT rule rewrites sources to.
    virtual_ip: Ipv6Addr,
    /// FIPS mesh address the DNAT rule rewrites destinations to; also the
    /// source address matched by the SNAT rule.
    mesh_addr: Ipv6Addr,
}
53
54/// NAT rule manager using nftables via rustables netlink API.
55///
56/// Rebuilds the entire nftables table atomically on every change to
57/// avoid relying on kernel rule handle tracking (which rustables
58/// doesn't expose). The table is small (one masquerade + two rules
59/// per mapping) so this is cheap.
60pub struct NatManager {
61    table: Table,
62    pre_chain: Chain,
63    post_chain: Chain,
64    /// LAN interface name, used to gate the port-forward LAN-side
65    /// masquerade rule (distinct from the fips0 egress masquerade).
66    lan_interface: String,
67    /// Active mappings keyed by virtual IP.
68    mappings: HashMap<Ipv6Addr, NatMapping>,
69    /// Inbound port-forward rules (TASK-2026-0061).
70    port_forwards: Vec<PortForward>,
71}
72
73impl NatManager {
74    /// Create the nftables table and NAT chains.
75    ///
76    /// Installs a masquerade rule for traffic exiting via `fips0` so that
77    /// LAN client source addresses are rewritten to the gateway's mesh
78    /// address, allowing return traffic to route back through the mesh.
79    ///
80    /// `lan_interface` is the gateway's LAN-facing interface name,
81    /// needed by the port-forward LAN-side masquerade rule.
82    pub fn new(lan_interface: String) -> Result<Self, NatError> {
83        let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME);
84        let pre_chain = Chain::new(&table)
85            .with_name(PREROUTING_CHAIN)
86            .with_type(ChainType::Nat)
87            .with_hook(Hook::new(HookClass::PreRouting, DSTNAT_PRIORITY));
88        let post_chain = Chain::new(&table)
89            .with_name(POSTROUTING_CHAIN)
90            .with_type(ChainType::Nat)
91            .with_hook(Hook::new(HookClass::PostRouting, SRCNAT_PRIORITY));
92
93        let mgr = Self {
94            table,
95            pre_chain,
96            post_chain,
97            lan_interface,
98            mappings: HashMap::new(),
99            port_forwards: Vec::new(),
100        };
101        mgr.rebuild()?;
102
103        info!("Created nftables table '{TABLE_NAME}' with NAT chains and fips0 masquerade");
104        Ok(mgr)
105    }
106
107    /// Replace the current inbound port-forward rule set and rebuild
108    /// the nftables table atomically. Pass an empty slice to clear.
109    pub fn set_port_forwards(&mut self, forwards: &[PortForward]) -> Result<(), NatError> {
110        self.port_forwards = forwards.to_vec();
111        self.rebuild()?;
112        info!(
113            count = self.port_forwards.len(),
114            "Applied inbound port forwards"
115        );
116        Ok(())
117    }
118
119    /// Add DNAT and SNAT rules for a virtual IP ↔ mesh address mapping.
120    pub fn add_mapping(
121        &mut self,
122        virtual_ip: Ipv6Addr,
123        mesh_addr: Ipv6Addr,
124    ) -> Result<(), NatError> {
125        self.mappings.insert(
126            virtual_ip,
127            NatMapping {
128                virtual_ip,
129                mesh_addr,
130            },
131        );
132        self.rebuild()?;
133
134        debug!(
135            virtual_ip = %virtual_ip,
136            mesh_addr = %mesh_addr,
137            "Added DNAT/SNAT rules"
138        );
139        Ok(())
140    }
141
142    /// Remove DNAT and SNAT rules for a virtual IP mapping.
143    pub fn remove_mapping(&mut self, virtual_ip: Ipv6Addr) -> Result<(), NatError> {
144        if self.mappings.remove(&virtual_ip).is_none() {
145            return Err(NatError::RuleNotFound(virtual_ip));
146        }
147        self.rebuild()?;
148
149        debug!(virtual_ip = %virtual_ip, "Removed DNAT/SNAT rules");
150        Ok(())
151    }
152
153    /// Flush all rules and delete the nftables table.
154    pub fn cleanup(self) -> Result<(), NatError> {
155        let mut batch = Batch::new();
156        batch.add(&self.table, MsgType::Del);
157        batch
158            .send()
159            .map_err(|e| NatError::Nftables(e.to_string()))?;
160
161        info!("Deleted nftables table '{TABLE_NAME}'");
162        Ok(())
163    }
164
165    /// Number of active NAT mappings.
166    pub fn mapping_count(&self) -> usize {
167        self.mappings.len()
168    }
169
170    /// Atomically rebuild the entire nftables table with all current
171    /// rules. Deletes and recreates the table, chains, masquerade rule,
172    /// and all per-mapping DNAT/SNAT rules in a single netlink batch.
173    fn rebuild(&self) -> Result<(), NatError> {
174        // Delete existing table in a separate batch — ignore ENOENT on
175        // first call when the table doesn't exist yet.
176        let mut del_batch = Batch::new();
177        del_batch.add(&self.table, MsgType::Del);
178        let _ = del_batch.send();
179
180        // Recreate table, chains, and all rules atomically.
181        let mut batch = Batch::new();
182        batch.add(&self.table, MsgType::Add);
183        batch.add(&self.pre_chain, MsgType::Add);
184        batch.add(&self.post_chain, MsgType::Add);
185
186        // Masquerade rule: rewrite source address for traffic exiting fips0.
187        // Without this, LAN clients' source addresses (e.g. fd02::20) are
188        // not routable on the mesh, so return traffic would be black-holed.
189        let masq_rule = Rule::new(&self.post_chain)?
190            .with_expr(Meta::new(MetaType::OifName))
191            .with_expr(Cmp::new(CmpOp::Eq, b"fips0\0".to_vec()))
192            .with_expr(Masquerade::default());
193        batch.add(&masq_rule, MsgType::Add);
194
195        // Per-mapping DNAT/SNAT rules.
196        for mapping in self.mappings.values() {
197            let dnat_rule = Rule::new(&self.pre_chain)?
198                .with_expr(Meta::new(MetaType::NfProto))
199                .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
200                .with_expr(
201                    HighLevelPayload::Network(NetworkHeaderField::IPv6(IPv6HeaderField::Daddr))
202                        .build(),
203                )
204                .with_expr(Cmp::new(CmpOp::Eq, mapping.virtual_ip.octets()))
205                .with_expr(Immediate::new_data(
206                    mapping.mesh_addr.octets().to_vec(),
207                    Register::Reg1,
208                ))
209                .with_expr(
210                    Nat::default()
211                        .with_nat_type(NatType::DNat)
212                        .with_family(ProtocolFamily::Ipv6)
213                        .with_ip_register(Register::Reg1),
214                );
215            batch.add(&dnat_rule, MsgType::Add);
216
217            let snat_rule = Rule::new(&self.post_chain)?
218                .with_expr(Meta::new(MetaType::NfProto))
219                .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
220                .with_expr(
221                    HighLevelPayload::Network(NetworkHeaderField::IPv6(IPv6HeaderField::Saddr))
222                        .build(),
223                )
224                .with_expr(Cmp::new(CmpOp::Eq, mapping.mesh_addr.octets()))
225                .with_expr(Immediate::new_data(
226                    mapping.virtual_ip.octets().to_vec(),
227                    Register::Reg1,
228                ))
229                .with_expr(
230                    Nat::default()
231                        .with_nat_type(NatType::SNat)
232                        .with_family(ProtocolFamily::Ipv6)
233                        .with_ip_register(Register::Reg1),
234                );
235            batch.add(&snat_rule, MsgType::Add);
236        }
237
238        // Inbound port-forward rules (TASK-2026-0061). Each forward is
239        // one DNAT rule in prerouting keyed on (iif fips0, nfproto ipv6,
240        // l4proto, th dport). When any forwards are configured, emit a
241        // single LAN-side masquerade in postrouting so the LAN target
242        // host sees the gateway's LAN address as source and replies
243        // flow back through conntrack.
244        for pf in &self.port_forwards {
245            let l4proto: u8 = match pf.proto {
246                Proto::Tcp => libc::IPPROTO_TCP as u8,
247                Proto::Udp => libc::IPPROTO_UDP as u8,
248            };
249            let dport_field = match pf.proto {
250                Proto::Tcp => TransportHeaderField::Tcp(TCPHeaderField::Dport),
251                Proto::Udp => TransportHeaderField::Udp(UDPHeaderField::Dport),
252            };
253            let target_ip = *pf.target.ip();
254            let target_port_be = pf.target.port().to_be_bytes();
255
256            let dnat_rule = Rule::new(&self.pre_chain)?
257                .with_expr(Meta::new(MetaType::IifName))
258                .with_expr(Cmp::new(CmpOp::Eq, b"fips0\0".to_vec()))
259                .with_expr(Meta::new(MetaType::NfProto))
260                .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
261                .with_expr(Meta::new(MetaType::L4Proto))
262                .with_expr(Cmp::new(CmpOp::Eq, [l4proto]))
263                .with_expr(HighLevelPayload::Transport(dport_field).build())
264                .with_expr(Cmp::new(CmpOp::Eq, pf.listen_port.to_be_bytes().to_vec()))
265                .with_expr(Immediate::new_data(
266                    target_ip.octets().to_vec(),
267                    Register::Reg1,
268                ))
269                .with_expr(Immediate::new_data(target_port_be.to_vec(), Register::Reg2))
270                .with_expr(
271                    Nat::default()
272                        .with_nat_type(NatType::DNat)
273                        .with_family(ProtocolFamily::Ipv6)
274                        .with_ip_register(Register::Reg1)
275                        .with_port_register(Register::Reg2),
276                );
277            batch.add(&dnat_rule, MsgType::Add);
278        }
279
280        if !self.port_forwards.is_empty() {
281            let mut lan_iface = self.lan_interface.clone().into_bytes();
282            lan_iface.push(0);
283            let lan_masq = Rule::new(&self.post_chain)?
284                .with_expr(Meta::new(MetaType::IifName))
285                .with_expr(Cmp::new(CmpOp::Eq, b"fips0\0".to_vec()))
286                .with_expr(Meta::new(MetaType::OifName))
287                .with_expr(Cmp::new(CmpOp::Eq, lan_iface))
288                .with_expr(Meta::new(MetaType::NfProto))
289                .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
290                .with_expr(Masquerade::default());
291            batch.add(&lan_masq, MsgType::Add);
292        }
293
294        batch
295            .send()
296            .map_err(|e| NatError::Nftables(e.to_string()))?;
297        Ok(())
298    }
299}