1use async_backtrace::frame;
13use async_backtrace::framed;
14use neli::consts::nl::*;
15use neli::consts::rtnl::*;
16use neli::consts::socket::*;
17use neli::nl::NlPayload;
18use neli::nl::Nlmsghdr;
19use neli::router::asynchronous::NlRouter;
20use neli::rtnl::*;
21use neli::types::RtBuffer;
22use neli::utils::Groups;
23use std::mem;
24use std::net::IpAddr;
25use std::net::Ipv4Addr;
26use tokio::sync::mpsc;
27use tokio::sync::oneshot;
28use tracing::Instrument;
29use tracing::debug;
30use tracing::error;
31use tracing::error_span;
32use tracing::info;
33
34pub enum Firewall {
36 DropTraffic {
37 ips: Vec<Ipv4Addr>,
38 tx: oneshot::Sender<()>,
39 },
40 TurnOffRules {
41 tx: oneshot::Sender<()>,
42 },
43}
44
/// Convenience methods for driving the firewall actor through its request channel.
///
/// The `mpsc::Sender<Firewall>` implementation sends one [`Firewall`] request per call and
/// resolves once the actor has acknowledged it.
pub trait FirewallExt {
    /// Block traffic to `ips`, replacing whichever addresses were blocked before.
    fn drop_traffic(&self, ips: Vec<Ipv4Addr>) -> impl std::future::Future<Output = ()>;

    /// Unblock every address blocked by the most recent `drop_traffic` call.
    fn turn_off_rules(&self) -> impl std::future::Future<Output = ()>;
}
54
55impl FirewallExt for mpsc::Sender<Firewall> {
56 #[framed]
57 async fn drop_traffic(&self, ips: Vec<Ipv4Addr>) {
58 let (tx, rx) = oneshot::channel();
59 self.send(Firewall::DropTraffic { ips, tx })
60 .await
61 .expect("FirewallExt::drop_traffic: internal actor should receive request");
62 rx.await
63 .expect("FirewallExt::drop_traffic: internal actor should send response")
64 }
65
66 #[framed]
67 async fn turn_off_rules(&self) {
68 let (tx, rx) = oneshot::channel();
69 self.send(Firewall::TurnOffRules { tx })
70 .await
71 .expect("FirewallExt::turn_off_rules: internal actor should receive request");
72 rx.await
73 .expect("FirewallExt::turn_off_rules: internal actor should send response")
74 }
75}
76
77#[framed]
79pub async fn new() -> mpsc::Sender<Firewall> {
80 let (tx, mut rx) = mpsc::channel(10);
81
82 tokio::spawn(
83 frame!(async move {
84 debug!("starting");
85
86 let (socket, _) = NlRouter::connect(NlFamily::Route, None, Groups::empty())
87 .await
88 .unwrap();
89
90 let mut disabled_ips = Vec::new();
91
92 while let Some(msg) = rx.recv().await {
93 process(msg, &socket, &mut disabled_ips).await;
94 }
95
96 debug!("finished");
97 })
98 .instrument(error_span!("firewall")),
99 );
100
101 tx
102}
103
104#[framed]
105async fn process(msg: Firewall, socket: &NlRouter, disabled_ips: &mut Vec<Ipv4Addr>) {
106 match msg {
107 Firewall::DropTraffic { ips, tx } => {
108 info!("Removing rules for: {disabled_ips:?}");
109 turn_off_rules(socket, mem::take(disabled_ips)).await;
110 *disabled_ips = ips;
111 info!("Adding rules for: {disabled_ips:?}");
112 drop_traffic(socket, disabled_ips).await;
113 if let Err(err) = log_routes(socket).await {
114 error!("Failed to list routes: {err}");
115 }
116 tx.send(())
117 .expect("process Firewall::DropTraffic: failed to send a response");
118 }
119
120 Firewall::TurnOffRules { tx } => {
121 info!("Removing rules for: {disabled_ips:?}");
122 turn_off_rules(socket, mem::take(disabled_ips)).await;
123 if let Err(err) = log_routes(socket).await {
124 error!("Failed to list routes: {err}");
125 }
126 tx.send(())
127 .expect("process Firewall::TurnOffRules: failed to send a response");
128 }
129 }
130}
131
132#[framed]
133async fn drop_traffic(socket: &NlRouter, ips: &[Ipv4Addr]) {
134 for ip in ips.iter() {
135 let Err(err) = add_unreachable_route(socket, ip).await else {
136 continue;
137 };
138 error!("Failed to add unreachable route for ip {ip}: {err}");
139 }
140}
141
142#[framed]
143async fn turn_off_rules(socket: &NlRouter, ips: Vec<Ipv4Addr>) {
144 for ip in ips.into_iter() {
145 let Err(err) = remove_unreachable_route(socket, ip).await else {
146 continue;
147 };
148 error!("Failed to remove unreachable route for ip {ip}: {err}");
149 }
150}
151
152async fn add_unreachable_route(socket: &NlRouter, ip: &Ipv4Addr) -> anyhow::Result<()> {
153 let mut attrs = RtBuffer::new();
154 attrs.push(
155 RtattrBuilder::default()
156 .rta_type(Rta::Dst)
157 .rta_payload(ip.octets())
158 .build()?,
159 );
160 let rtmsg = RtmsgBuilder::default()
161 .rtm_family(RtAddrFamily::Inet)
162 .rtm_dst_len(32)
163 .rtm_src_len(0)
164 .rtm_tos(0)
165 .rtm_table(RtTable::Main)
166 .rtm_protocol(Rtprot::Unspec)
167 .rtm_scope(RtScope::Universe)
168 .rtm_type(Rtn::Blackhole)
169 .rtattrs(attrs)
170 .build()?;
171 socket
172 .send::<Rtm, Rtmsg, NlTypeWrapper, Rtmsg>(
173 Rtm::Newroute,
174 NlmF::REQUEST | NlmF::CREATE | NlmF::REPLACE,
175 NlPayload::Payload(rtmsg),
176 )
177 .await?;
178 Ok(())
179}
180
181async fn remove_unreachable_route(socket: &NlRouter, ip: Ipv4Addr) -> anyhow::Result<()> {
182 let mut attrs = RtBuffer::new();
183 attrs.push(
184 RtattrBuilder::default()
185 .rta_type(Rta::Dst)
186 .rta_payload(ip.octets())
187 .build()?,
188 );
189 let rtmsg = RtmsgBuilder::default()
190 .rtm_family(RtAddrFamily::Inet)
191 .rtm_dst_len(32)
192 .rtm_src_len(0)
193 .rtm_tos(0)
194 .rtm_table(RtTable::Main)
195 .rtm_protocol(Rtprot::Unspec)
196 .rtm_scope(RtScope::Universe)
197 .rtm_type(Rtn::Blackhole)
198 .rtattrs(attrs)
199 .build()?;
200 socket
201 .send::<Rtm, Rtmsg, NlTypeWrapper, Rtmsg>(
202 Rtm::Delroute,
203 NlmF::REQUEST,
204 NlPayload::Payload(rtmsg),
205 )
206 .await?;
207 Ok(())
208}
209
210async fn log_routes(socket: &NlRouter) -> anyhow::Result<()> {
211 let rtmsg = RtmsgBuilder::default()
212 .rtm_family(RtAddrFamily::Inet)
213 .rtm_dst_len(0)
214 .rtm_src_len(0)
215 .rtm_tos(0)
216 .rtm_table(RtTable::Unspec)
217 .rtm_protocol(Rtprot::Unspec)
218 .rtm_scope(RtScope::Universe)
219 .rtm_type(Rtn::Unspec)
220 .build()?;
221 let mut recv = socket
222 .send::<Rtm, Rtmsg, NlTypeWrapper, Rtmsg>(
223 Rtm::Getroute,
224 NlmF::DUMP,
225 NlPayload::Payload(rtmsg),
226 )
227 .await?;
228
229 while let Some(rtm_result) = recv.next().await {
230 let rtm = rtm_result?;
231 if let NlTypeWrapper::Rtm(_) = rtm.nl_type() {
232 parse_route_table(rtm)?;
233 }
234 }
235
236 Ok(())
237}
238
239fn parse_route_table(rtm: Nlmsghdr<NlTypeWrapper, Rtmsg>) -> anyhow::Result<()> {
240 if let Some(payload) = rtm.get_payload() {
241 let mut dst = None;
242
243 for attr in payload.rtattrs().iter() {
244 fn to_addr(b: &[u8]) -> Option<IpAddr> {
245 if let Ok(tup) = <&[u8; 4]>::try_from(b) {
246 Some(IpAddr::from(*tup))
247 } else if let Ok(tup) = <&[u8; 16]>::try_from(b) {
248 Some(IpAddr::from(*tup))
249 } else {
250 None
251 }
252 }
253
254 if attr.rta_type() == &Rta::Dst {
255 dst = to_addr(attr.rta_payload().as_ref())
256 }
257 }
258
259 let dst = if let Some(dst) = dst {
260 format!("{}/{} ", dst, payload.rtm_dst_len())
261 } else {
262 "default".to_string()
263 };
264
265 info!(
266 "active route for {:?}: {dst}: {:?}",
267 payload.rtm_table(),
268 payload.rtm_type()
269 );
270 }
271 Ok(())
272}