1use std::collections::HashMap;
7use std::net::Ipv6Addr;
8use tracing::{debug, info};
9
10use rustables::expr::{
11 Cmp, CmpOp, HighLevelPayload, IPv6HeaderField, Immediate, Masquerade, Meta, MetaType, Nat,
12 NatType, NetworkHeaderField, Register, TCPHeaderField, TransportHeaderField, UDPHeaderField,
13};
14use rustables::{Batch, Chain, ChainType, Hook, HookClass, MsgType, ProtocolFamily, Rule, Table};
15
16use crate::config::{PortForward, Proto};
17
/// Name of the nftables table that owns every chain and rule created here.
const TABLE_NAME: &str = "fips_gateway";

// Names of the two NAT chains registered inside the table.
const PREROUTING_CHAIN: &str = "prerouting";
const POSTROUTING_CHAIN: &str = "postrouting";

/// Hook priority of the prerouting chain (the value nftables calls `dstnat`).
const DSTNAT_PRIORITY: i32 = -100;
/// Hook priority of the postrouting chain (the value nftables calls `srcnat`).
const SRCNAT_PRIORITY: i32 = 100;
25
26#[derive(Debug, thiserror::Error)]
28pub enum NatError {
29 #[error("nftables error: {0}")]
30 Nftables(String),
31 #[error("rule not found for virtual IP {0}")]
32 RuleNotFound(Ipv6Addr),
33}
34
35impl From<rustables::error::QueryError> for NatError {
36 fn from(e: rustables::error::QueryError) -> Self {
37 NatError::Nftables(e.to_string())
38 }
39}
40
41impl From<rustables::error::BuilderError> for NatError {
42 fn from(e: rustables::error::BuilderError) -> Self {
43 NatError::Nftables(e.to_string())
44 }
45}
46
/// One virtual-IP ↔ mesh-address translation pair.
///
/// Both fields are plain `Ipv6Addr` values, so the type is cheap to copy and
/// can be compared and printed for diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct NatMapping {
    /// Address matched as the packet destination by the DNAT rule and written
    /// as the new source by the SNAT rule.
    virtual_ip: Ipv6Addr,
    /// Address written as the new destination by the DNAT rule and matched as
    /// the packet source by the SNAT rule.
    mesh_addr: Ipv6Addr,
}
53
54pub struct NatManager {
61 table: Table,
62 pre_chain: Chain,
63 post_chain: Chain,
64 lan_interface: String,
67 mappings: HashMap<Ipv6Addr, NatMapping>,
69 port_forwards: Vec<PortForward>,
71}
72
73impl NatManager {
74 pub fn new(lan_interface: String) -> Result<Self, NatError> {
83 let table = Table::new(ProtocolFamily::Inet).with_name(TABLE_NAME);
84 let pre_chain = Chain::new(&table)
85 .with_name(PREROUTING_CHAIN)
86 .with_type(ChainType::Nat)
87 .with_hook(Hook::new(HookClass::PreRouting, DSTNAT_PRIORITY));
88 let post_chain = Chain::new(&table)
89 .with_name(POSTROUTING_CHAIN)
90 .with_type(ChainType::Nat)
91 .with_hook(Hook::new(HookClass::PostRouting, SRCNAT_PRIORITY));
92
93 let mgr = Self {
94 table,
95 pre_chain,
96 post_chain,
97 lan_interface,
98 mappings: HashMap::new(),
99 port_forwards: Vec::new(),
100 };
101 mgr.rebuild()?;
102
103 info!("Created nftables table '{TABLE_NAME}' with NAT chains and fips0 masquerade");
104 Ok(mgr)
105 }
106
107 pub fn set_port_forwards(&mut self, forwards: &[PortForward]) -> Result<(), NatError> {
110 self.port_forwards = forwards.to_vec();
111 self.rebuild()?;
112 info!(
113 count = self.port_forwards.len(),
114 "Applied inbound port forwards"
115 );
116 Ok(())
117 }
118
119 pub fn add_mapping(
121 &mut self,
122 virtual_ip: Ipv6Addr,
123 mesh_addr: Ipv6Addr,
124 ) -> Result<(), NatError> {
125 self.mappings.insert(
126 virtual_ip,
127 NatMapping {
128 virtual_ip,
129 mesh_addr,
130 },
131 );
132 self.rebuild()?;
133
134 debug!(
135 virtual_ip = %virtual_ip,
136 mesh_addr = %mesh_addr,
137 "Added DNAT/SNAT rules"
138 );
139 Ok(())
140 }
141
142 pub fn remove_mapping(&mut self, virtual_ip: Ipv6Addr) -> Result<(), NatError> {
144 if self.mappings.remove(&virtual_ip).is_none() {
145 return Err(NatError::RuleNotFound(virtual_ip));
146 }
147 self.rebuild()?;
148
149 debug!(virtual_ip = %virtual_ip, "Removed DNAT/SNAT rules");
150 Ok(())
151 }
152
153 pub fn cleanup(self) -> Result<(), NatError> {
155 let mut batch = Batch::new();
156 batch.add(&self.table, MsgType::Del);
157 batch
158 .send()
159 .map_err(|e| NatError::Nftables(e.to_string()))?;
160
161 info!("Deleted nftables table '{TABLE_NAME}'");
162 Ok(())
163 }
164
165 pub fn mapping_count(&self) -> usize {
167 self.mappings.len()
168 }
169
170 fn rebuild(&self) -> Result<(), NatError> {
174 let mut del_batch = Batch::new();
177 del_batch.add(&self.table, MsgType::Del);
178 let _ = del_batch.send();
179
180 let mut batch = Batch::new();
182 batch.add(&self.table, MsgType::Add);
183 batch.add(&self.pre_chain, MsgType::Add);
184 batch.add(&self.post_chain, MsgType::Add);
185
186 let masq_rule = Rule::new(&self.post_chain)?
190 .with_expr(Meta::new(MetaType::OifName))
191 .with_expr(Cmp::new(CmpOp::Eq, b"fips0\0".to_vec()))
192 .with_expr(Masquerade::default());
193 batch.add(&masq_rule, MsgType::Add);
194
195 for mapping in self.mappings.values() {
197 let dnat_rule = Rule::new(&self.pre_chain)?
198 .with_expr(Meta::new(MetaType::NfProto))
199 .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
200 .with_expr(
201 HighLevelPayload::Network(NetworkHeaderField::IPv6(IPv6HeaderField::Daddr))
202 .build(),
203 )
204 .with_expr(Cmp::new(CmpOp::Eq, mapping.virtual_ip.octets()))
205 .with_expr(Immediate::new_data(
206 mapping.mesh_addr.octets().to_vec(),
207 Register::Reg1,
208 ))
209 .with_expr(
210 Nat::default()
211 .with_nat_type(NatType::DNat)
212 .with_family(ProtocolFamily::Ipv6)
213 .with_ip_register(Register::Reg1),
214 );
215 batch.add(&dnat_rule, MsgType::Add);
216
217 let snat_rule = Rule::new(&self.post_chain)?
218 .with_expr(Meta::new(MetaType::NfProto))
219 .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
220 .with_expr(
221 HighLevelPayload::Network(NetworkHeaderField::IPv6(IPv6HeaderField::Saddr))
222 .build(),
223 )
224 .with_expr(Cmp::new(CmpOp::Eq, mapping.mesh_addr.octets()))
225 .with_expr(Immediate::new_data(
226 mapping.virtual_ip.octets().to_vec(),
227 Register::Reg1,
228 ))
229 .with_expr(
230 Nat::default()
231 .with_nat_type(NatType::SNat)
232 .with_family(ProtocolFamily::Ipv6)
233 .with_ip_register(Register::Reg1),
234 );
235 batch.add(&snat_rule, MsgType::Add);
236 }
237
238 for pf in &self.port_forwards {
245 let l4proto: u8 = match pf.proto {
246 Proto::Tcp => libc::IPPROTO_TCP as u8,
247 Proto::Udp => libc::IPPROTO_UDP as u8,
248 };
249 let dport_field = match pf.proto {
250 Proto::Tcp => TransportHeaderField::Tcp(TCPHeaderField::Dport),
251 Proto::Udp => TransportHeaderField::Udp(UDPHeaderField::Dport),
252 };
253 let target_ip = *pf.target.ip();
254 let target_port_be = pf.target.port().to_be_bytes();
255
256 let dnat_rule = Rule::new(&self.pre_chain)?
257 .with_expr(Meta::new(MetaType::IifName))
258 .with_expr(Cmp::new(CmpOp::Eq, b"fips0\0".to_vec()))
259 .with_expr(Meta::new(MetaType::NfProto))
260 .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
261 .with_expr(Meta::new(MetaType::L4Proto))
262 .with_expr(Cmp::new(CmpOp::Eq, [l4proto]))
263 .with_expr(HighLevelPayload::Transport(dport_field).build())
264 .with_expr(Cmp::new(CmpOp::Eq, pf.listen_port.to_be_bytes().to_vec()))
265 .with_expr(Immediate::new_data(
266 target_ip.octets().to_vec(),
267 Register::Reg1,
268 ))
269 .with_expr(Immediate::new_data(target_port_be.to_vec(), Register::Reg2))
270 .with_expr(
271 Nat::default()
272 .with_nat_type(NatType::DNat)
273 .with_family(ProtocolFamily::Ipv6)
274 .with_ip_register(Register::Reg1)
275 .with_port_register(Register::Reg2),
276 );
277 batch.add(&dnat_rule, MsgType::Add);
278 }
279
280 if !self.port_forwards.is_empty() {
281 let mut lan_iface = self.lan_interface.clone().into_bytes();
282 lan_iface.push(0);
283 let lan_masq = Rule::new(&self.post_chain)?
284 .with_expr(Meta::new(MetaType::IifName))
285 .with_expr(Cmp::new(CmpOp::Eq, b"fips0\0".to_vec()))
286 .with_expr(Meta::new(MetaType::OifName))
287 .with_expr(Cmp::new(CmpOp::Eq, lan_iface))
288 .with_expr(Meta::new(MetaType::NfProto))
289 .with_expr(Cmp::new(CmpOp::Eq, [libc::NFPROTO_IPV6 as u8]))
290 .with_expr(Masquerade::default());
291 batch.add(&lan_masq, MsgType::Add);
292 }
293
294 batch
295 .send()
296 .map_err(|e| NatError::Nftables(e.to_string()))?;
297 Ok(())
298 }
299}